## 0. config

In [1]:
from simplecv.util import config
from simplecv.core.config import AttrDict

In [2]:
config_path = 'RSSGL.RSSGL_Pavia'

cfg = config.import_config(config_path)

In [3]:
cfg = AttrDict.from_dict(cfg)

In [4]:
opts = ['train.save_ckpt_interval_epoch', '9999']
opts

['train.save_ckpt_interval_epoch', '9999']

In [5]:
cfg.update_from_list(opts)

In [6]:
cfg

{'model': {'type': 'RSSGL',
  'params': {'in_channels': 103,
   'num_classes': 9,
   'block_channels': (96, 128, 192, 256),
   'inner_dim': 128,
   'reduction_ratio': 1.0}},
 'data': {'train': {'type': 'NewPaviaLoader',
   'params': {'training': True,
    'num_workers': 4,
    'image_mat_path': './pavia/PaviaU.mat',
    'gt_mat_path': './pavia/PaviaU_gt.mat',
    'sample_percent': 0.01,
    'batch_size': 10}},
  'test': {'type': 'NewPaviaLoader',
   'params': {'training': False,
    'num_workers': 4,
    'image_mat_path': './pavia/PaviaU.mat',
    'gt_mat_path': './pavia/PaviaU_gt.mat',
    'sample_percent': 0.01,
    'batch_size': 10}}},
 'optimizer': {'type': 'sgd',
  'params': {'momentum': 0.9, 'weight_decay': 0.001}},
 'learning_rate': {'type': 'poly',
  'params': {'base_lr': 0.005, 'power': 0.8, 'max_iters': 1000}},
 'train': {'forward_times': 1,
  'num_iters': 2,
  'eval_per_epoch': True,
  'summary_grads': False,
  'summary_weights': False,
  'eval_after_train': True,
  'resume_

## 1. model

In [7]:
import torch
import torch.nn as nn
import numpy as np
from module import RSSGL
from simplecv.module.model_builder import make_model
from simplecv import registry

In [8]:
registry.MODEL

{'resnet18': <function simplecv.module._resnets.resnet18(pretrained=False, progress=True, **kwargs)>,
 'resnet34': <function simplecv.module._resnets.resnet34(pretrained=False, progress=True, **kwargs)>,
 'resnet50': <function simplecv.module._resnets.resnet50(pretrained=False, progress=True, **kwargs)>,
 'resnet101': <function simplecv.module._resnets.resnet101(pretrained=False, progress=True, **kwargs)>,
 'resnext50_32x4d': <function simplecv.module._resnets.resnext50_32x4d(pretrained=False, progress=True, **kwargs)>,
 'resnext101_32x4d': <function simplecv.module._resnets.resnext101_32x4d(pretrained=False, progress=True, **kwargs)>,
 'resnext101_32x8d': <function simplecv.module._resnets.resnext101_32x8d(pretrained=False, progress=True, **kwargs)>,
 'resnet_encoder': simplecv.module.resnet.ResNetEncoder,
 'RSSGL': module.RSSGL.RSSGL}

In [9]:
config = cfg['model']
model_type = config['type']
model_type

'RSSGL'

In [10]:
model = registry.MODEL[model_type](config['params'])

In [11]:
model

RSSGL(
  (feature_ops): ModuleList(
    (0): Sequential(
      (0): Conv2d(103, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): GroupNorm(4, 96, eps=1e-05, affine=True)
      (2): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): Si_ConvLSTM(
        (cell0): ConvLSTMCell(
          (Wxi): Conv3d(1, 4, kernel_size=(3, 5, 5), stride=(1, 1, 1), padding=(1, 2, 2))
          (Whi): Conv3d(4, 4, kernel_size=(3, 5, 5), stride=(1, 1, 1), padding=(1, 2, 2), bias=False)
          (Wxf): Conv3d(1, 4, kernel_size=(3, 5, 5), stride=(1, 1, 1), padding=(1, 2, 2))
          (Whf): Conv3d(4, 4, kernel_size=(3, 5, 5), stride=(1, 1, 1), padding=(1, 2, 2), bias=False)
          (Wxc): Conv3d(1, 4, kernel_size=(3, 5, 5), stride=(1, 1, 1), padding=(1, 2, 2))
          (Whc): Conv3d(4, 4, kernel_size=(3, 5, 5), stride=(1, 1, 1), padding=(1, 2, 2), bias=False)
          (Wxo): Conv3d(1, 4, kernel_size=(3, 5, 5), stride=(1, 1, 1), padding=(1, 2, 2))
          (Who): Conv3d(4, 4, kerne

In [12]:
model.to(torch.device('cuda'))
model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))

## 2. data

In [13]:
from simplecv.data.data_loader import make_dataloader
import data.dataloader

In [14]:
config = cfg['data']['train']
config

{'type': 'NewPaviaLoader',
 'params': {'training': True,
  'num_workers': 4,
  'image_mat_path': './pavia/PaviaU.mat',
  'gt_mat_path': './pavia/PaviaU_gt.mat',
  'sample_percent': 0.01,
  'batch_size': 10}}

In [15]:
dataloader_type = config['type']
dataloader_type

'NewPaviaLoader'

In [16]:
registry.DATALOADER

{'NewPaviaLoader': data.dataloader.NewPaviaLoader}

In [17]:
config['params']

{'training': True,
 'num_workers': 4,
 'image_mat_path': './pavia/PaviaU.mat',
 'gt_mat_path': './pavia/PaviaU_gt.mat',
 'sample_percent': 0.01,
 'batch_size': 10}

In [18]:
traindata_loader = registry.DATALOADER[dataloader_type](config['params'])

In [19]:
testdata_loader = make_dataloader(cfg['data']['test']) if 'test' in cfg['data'] else None

In [20]:
len(traindata_loader)

10

In [21]:
for idx, (im, mask, w) in enumerate(traindata_loader):
    print(im.shape, mask.shape, w.shape)

torch.Size([1, 103, 624, 352]) torch.Size([1, 624, 352]) torch.Size([1, 624, 352])
torch.Size([1, 103, 624, 352]) torch.Size([1, 624, 352]) torch.Size([1, 624, 352])
torch.Size([1, 103, 624, 352]) torch.Size([1, 624, 352]) torch.Size([1, 624, 352])
torch.Size([1, 103, 624, 352]) torch.Size([1, 624, 352]) torch.Size([1, 624, 352])
torch.Size([1, 103, 624, 352]) torch.Size([1, 624, 352]) torch.Size([1, 624, 352])
torch.Size([1, 103, 624, 352]) torch.Size([1, 624, 352]) torch.Size([1, 624, 352])
torch.Size([1, 103, 624, 352]) torch.Size([1, 624, 352]) torch.Size([1, 624, 352])
torch.Size([1, 103, 624, 352]) torch.Size([1, 624, 352]) torch.Size([1, 624, 352])
torch.Size([1, 103, 624, 352]) torch.Size([1, 624, 352]) torch.Size([1, 624, 352])
torch.Size([1, 103, 624, 352]) torch.Size([1, 624, 352]) torch.Size([1, 624, 352])


In [22]:
for idx, (im, mask, w) in enumerate(testdata_loader):
    print(im.shape, mask.shape, w.shape)

torch.Size([1, 103, 624, 352]) torch.Size([1, 624, 352]) torch.Size([1, 624, 352])


## 3. optimizer

In [23]:
from simplecv.opt.learning_rate import make_learningrate
import numpy as np
from simplecv.util import registry
from simplecv.interface import LearningRateBase
import math
from simplecv.opt.optimizer import make_optimizer

In [24]:
config = cfg['learning_rate']
config

{'type': 'poly', 'params': {'base_lr': 0.005, 'power': 0.8, 'max_iters': 1000}}

In [25]:
lr_type = config['type']
lr_type

'poly'

In [26]:
registry.LR

{'multistep': simplecv.opt.learning_rate.MultiStepLearningRate,
 'poly': simplecv.opt.learning_rate.PolyLearningRate,
 'cosine': simplecv.opt.learning_rate.CosineAnnealingLearningRate}

In [27]:
lr_module = registry.LR[lr_type]
lr_schedule = lr_module(**config['params'])

In [28]:
lr_schedule.__dict__

{'_base_lr': 0.005, 'power': 0.8, 'max_iters': 1000}

In [29]:
lr_schedule.base_lr

0.005

In [30]:
cfg['optimizer']['params']['lr'] = lr_schedule.base_lr

In [31]:
cfg['optimizer']['params']

{'momentum': 0.9, 'weight_decay': 0.001, 'lr': 0.005}

In [32]:
config = cfg['optimizer']
params = model.parameters()

In [33]:
registry.OPT

{'sgd': torch.optim.sgd.SGD,
 'adam': torch.optim.adam.Adam,
 'fused_adam': apex.optimizers.fused_adam.FusedAdam}

In [34]:
opt_type = config['type']
opt_type

'sgd'

In [35]:
optimizer = registry.OPT[opt_type](params=params, **config['params'])

In [36]:
optimizer.simplecv_config = config

In [37]:
optimizer

SGD (
Parameter Group 0
    dampening: 0
    lr: 0.005
    momentum: 0.9
    nesterov: False
    weight_decay: 0.001
)

In [38]:
from simplecv.core import trainer

In [39]:
tl = trainer.Launcher(model_dir='./log/pavia/SSDGL/1.0_poly', 
                      model=model, 
                      optimizer=optimizer, 
                      lr_schedule=lr_schedule)

## 4. GCLAM

In [40]:
import torch.nn.functional as F
from simplecv.interface import CVModule
from torch.autograd import Variable

In [41]:
def conv3x3_gn_relu(in_channel, out_channel, num_group):
    return nn.Sequential(
        nn.Conv2d(in_channel, out_channel, 3, 1, 1), 
        nn.GroupNorm(num_group, out_channel), 
        nn.ReLU(inplace=True))


def gn_relu(in_channel, num_group):
    return nn.Sequential(
        nn.GroupNorm(num_group, in_channel), 
        nn.ReLU(inplace=True))


def downsample2x(in_channel, out_channel):
    return nn.Sequential(
        nn.Conv2d(in_channel, out_channel, 3, 2, 1), 
        nn.ReLU(inplace=True))

In [42]:
dataset_path = "SSDGL.SSDGL_1_0_pavia"

In [43]:
class ConvLSTMCell(nn.Module):
    def __init__(self, input_channels, hidden_channels, kernel_size):
        super(ConvLSTMCell, self).__init__()

        assert hidden_channels % 2 == 0

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.num_features = 4

        # self.padding = int((kernel_size - 1) / 2)
        self.padding = tuple((int((i-1)/2) for i in kernel_size))

        self.Wxi = nn.Conv3d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whi = nn.Conv3d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxf = nn.Conv3d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whf = nn.Conv3d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxc = nn.Conv3d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whc = nn.Conv3d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxo = nn.Conv3d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Who = nn.Conv3d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)

        self.Wci = None
        self.Wcf = None
        self.Wco = None

    def forward(self, x, h, c):
        ci = torch.sigmoid(self.Wxi(x) + self.Whi(h) + c * self.Wci)
        cf = torch.sigmoid(self.Wxf(x) + self.Whf(h) + c * self.Wcf)
        cc = cf * c + ci * torch.tanh(self.Wxc(x) + self.Whc(h))
        co = torch.sigmoid(self.Wxo(x) + self.Who(h) + cc * self.Wco)
        ch = co * torch.tanh(cc)
        return ch, cc

    def init_hidden(self, batch_size, hidden, shape):
        if self.Wci is None:
            self.Wci = Variable(torch.zeros(1, hidden, shape[0], shape[1], shape[2])).cuda()
            self.Wcf = Variable(torch.zeros(1, hidden, shape[0], shape[1], shape[2])).cuda()
            self.Wco = Variable(torch.zeros(1, hidden, shape[0], shape[1], shape[2])).cuda()
        else:
            assert shape[0] == self.Wci.size()[2], 'Input Height Mismatched!'
            assert shape[1] == self.Wci.size()[3], 'Input Width Mismatched!'
            assert shape[2] == self.Wci.size()[4], 'Input Dimension Mismatched!'
        return (Variable(torch.zeros(batch_size, hidden, shape[0], shape[1], shape[2])).cuda(),
                Variable(torch.zeros(batch_size, hidden, shape[0], shape[1], shape[2])).cuda())

In [44]:
class ConvLSTM(nn.Module):
    # input_channels corresponds to the first input feature map
    # hidden state is a list of succeeding lstm layers
    def __init__(self, input_channels, hidden_channels, kernel_size, step=8, effective_step=7):
        super(ConvLSTM, self).__init__()
        self.input_channels = [input_channels] + hidden_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.num_layers = len(hidden_channels)
        self.step = step
        self.effective_step = effective_step
        self._all_layers = []
        for i in range(self.num_layers):
            name = 'cell{}'.format(i)
            name_reverse = 'cell_reverse{}'.format(i)
            cell = ConvLSTMCell(self.input_channels[i], self.hidden_channels[i], self.kernel_size)
            cell_reverse = ConvLSTMCell(self.input_channels[i], self.hidden_channels[i], self.kernel_size)
            setattr(self, name, cell)
            setattr(self, name_reverse, cell_reverse)
            self._all_layers.append(cell)
            self._all_layers.append(cell_reverse)
        # self.recover_conv = nn.Conv2d(self.hidden_channels[0] * self.hidden_channels[1], 
        #                               self.hidden_channels[-1] * 8, 1)
    
    
    def forward(self, input):
        internal_state = []
        internal_state_reverse = []
        outputs = []
        outputs_reverse = []
        a = input.squeeze()
        b = int(len(a) / 8)
        for i in range(self.num_layers):
            name = "cell{}".format(i)
            name_reverse = "cell_reverse{}".format(i)
            if i == 0:
                for step in range(self.step):
                    if step == 0:
                        x_reverse = input[:, -(step + 1) * b:]
                    else:
                        x_reverse = input[:, -(step + 1) * b: -step * b]
                    x_reverse = x_reverse.unsqueeze(dim=1)  # (1, 1, 12, 624, 352), (1, 1, 16, 312, 176)...
            
                    x = input[:, step * b:(step + 1) * b, :, :]  # (1, 12, 624, 352), (1, 16, 312, 176)...
                    x = x.unsqueeze(dim=1)  # (1, 1, 12, 624, 352), (1, 1, 16, 312, 176)...
            
                    bsize, _, dimension, height, width = x.size()
                
                    if step == 0:
                        (h, c) = getattr(self, name).init_hidden(batch_size=bsize, hidden=self.hidden_channels[i], 
                                                                 shape=(dimension, height, width))
                        internal_state.append((h, c))
                    
                        (h_reverse, c_reverse) = getattr(self, name_reverse).init_hidden(batch_size=bsize, 
                                                                                         hidden=self.hidden_channels[i], 
                                                                                         shape=(dimension, height, width))
                        internal_state_reverse.append((h_reverse, c_reverse))
                    
                    # do forward
                    (h, c) = internal_state[i]
                    (h_reverse, c_reverse) = internal_state_reverse[i]
                
                    x, new_c = getattr(self, name)(x, h, c)
                    internal_state[i] = (x, new_c)
                
                    x_reverse, new_c_reverse = getattr(self, name_reverse)(x_reverse, h_reverse, c_reverse)
                    internal_state_reverse[i] = (x_reverse, new_c_reverse)
                
                    outputs.append(x)
                    outputs_reverse.insert(0, x_reverse)
                if self.num_layers == 1:
                    result = outputs[-1] + outputs_reverse[-1]
                    result = result[:, 0]
                    for i in range(self.hidden_channels[i] - 1):
                        result = torch.cat([result, x[:, i + 1]], dim=1)
                    return result
            else:
                input = torch.cat([outputs[j] + outputs_reverse[j] for j in range(self.step)], dim=1)
                b = self.hidden_channels[i - 1]
                outputs = []
                outputs_reverse = []
                for step in range(self.step):
                    if step == 0:
                        x_reverse = input[:, -(step + 1) * b:]
                    else:
                        x_reverse = input[:, -(step + 1) * b: -step * b]
            
                    x = input[:, step * b:(step + 1) * b]  # (1, 8, 12, 624, 352), (1, 8, 16, 312, 176)...
                
                    bsize, _, dimension, height, width = x.size()
                
                    if step == 0:
                        (h, c) = getattr(self, name).init_hidden(batch_size=bsize, hidden=self.hidden_channels[i], 
                                                                 shape=(dimension, height, width))
                        internal_state.append((h, c))
                    
                        (h_reverse, c_reverse) = getattr(self, name_reverse).init_hidden(batch_size=bsize, 
                                                                                         hidden=self.hidden_channels[i], 
                                                                                         shape=(dimension, height, width))
                        internal_state_reverse.append((h_reverse, c_reverse))
                    
                    # do forward
                    (h, c) = internal_state[i]
                    (h_reverse, c_reverse) = internal_state_reverse[i]
                
                    x, new_c = getattr(self, name)(x, h, c)
                    internal_state[i] = (x, new_c)
                
                    x_reverse, new_c_reverse = getattr(self, name_reverse)(x_reverse, h_reverse, c_reverse)
                    internal_state_reverse[i] = (x_reverse, new_c_reverse)
                
                    outputs.append(x)
                    outputs_reverse.insert(0, x_reverse)
                if i == self.num_layers - 1:
                    result = outputs[-1] + outputs_reverse[-1]
                    result = result[:, 0]
                    for i in range(self.hidden_channels[i] - 1):
                        result = torch.cat([result, x[:, i + 1]], dim=1)
                    return result 

In [45]:
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):  # 需要指定输入通道数
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # (batch_size, num_channels, 1, 1)
        
        self.max_pool = nn.AdaptiveMaxPool2d(1)  # (batch_size, num_channels, 1, 1)
        
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        residual = x
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = self.sigmoid(avg_out + max_out)
        
        y = x * out.view(out.size(0), out.size(1), 1, 1)
        
        y = y + residual
        return y

In [46]:
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.relu1 = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        residual = x  # (batch_size, num_channels, h, w)
        avg_out = torch.mean(x, dim=1, keepdim=True)  # (batch_size, 1, h, w)
        max_out = torch.max(x, dim=1, keepdim=True)
        
        out = torch.cat([avg_out, max_out], dim=1)  # (batch_size, 2, h, w)
        out1 = self.conv1(out)
        out2 = self.relu1(out1)
        out = self.sigmoid(out2)  # (batch_size, 1, h, w)
        
        y = x * out.view(out.size(0), 1, out.size(-2), out.size(-1))
        y = y + residual
        return  y

In [47]:
class BasicBlock(nn.Module):
    expansion = 1
    
    def __init__(self, planes):  # 需要指定输入通道数
        super(BasicBlock, self).__init__()
        self.ca = ChannelAttention(planes)
        self.sa = SpatialAttention()
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, x):
        out = torch.cat([self.ca(x), self.sa(x)], dim=1)  # (batch_size, planes * 2, h, w)
        return out

In [48]:
def repeat_block(block_channel, r, n):
    cl_channel = int(block_channel / 8)
    cl2_channel = int(cl_channel / 2)
    gn_a = int(block_channel / 2)
    # 注意分析每一个组成部分的输出形状，其中ConvLSTM将输入在光谱维度平均分为8份作为每一个时间步长的输入，最终输出维度为原始维度的1/2
    # 而SS注意力机制正好恢复出原始维度
    layers = (nn.Sequential(ConvLSTM(input_channels=1, hidden_channels=[4, 4], 
                                     kernel_size=(3, 5, 5), step=8, effective_step=7).cuda(), 
                            BasicBlock(gn_a), 
                            gn_relu(block_channel, r), ))
    return nn.Sequential(*layers)

In [49]:
config = cfg['model']['params']
config

{'in_channels': 103,
 'num_classes': 9,
 'block_channels': (96, 128, 192, 256),
 'inner_dim': 128,
 'reduction_ratio': 1.0}

In [50]:
class SSDGL(CVModule):
    def __init__(self, config):
        super(SSDGL, self).__init__(config)
        r = int(4 * self.config.reduction_ratio)  # The group number of group normalization: 4
        block1_channels = int(self.config.block_channels[0] * self.config.reduction_ratio / r) * r
        block2_channels = int(self.config.block_channels[1] * self.config.reduction_ratio / r) * r
        block3_channels = int(self.config.block_channels[2] * self.config.reduction_ratio / r) * r
        block4_channels = int(self.config.block_channels[3] * self.config.reduction_ratio / r) * r
        
        self.feature_ops = nn.ModuleList([
            conv3x3_gn_relu(self.config.in_channels, block1_channels, r),  # (batch_size, 96, 624, 352)
            
            repeat_block(block1_channels, r, self.config.num_blocks[0]),   # num_blocks=(1, 1, 1, 1)
            nn.Identity(),
            downsample2x(block1_channels, block2_channels),  # (batch_size, 128, 312, 176)
            
            repeat_block(block2_channels, r, self.config.num_blocks[1]),
            nn.Identity(),
            downsample2x(block2_channels, block3_channels),  # (batch_size, 192, 156, 88)
            
            repeat_block(block3_channels, r, self.config.num_blocks[2]),
            nn.Identity(),
            downsample2x(block3_channels, block4_channels),  # (batch_size, 256, 78, 44)
            
            repeat_block(block4_channels, r, self.config.num_blocks[3]),
            nn.Identity(), 
        ])
        inner_dim = int(self.config.inner_dim * self.config.reduction_ratio)
        
        self.BasicBlock_list = nn.ModuleList([
            BasicBlock(inner_dim), 
            BasicBlock(inner_dim),
            BasicBlock(inner_dim),
            BasicBlock(inner_dim),
        ])
        self.spation_list = nn.ModuleList([
            SpatialAttention(),
            SpatialAttention(),
            SpatialAttention(),
            SpatialAttention(),
        ])
        self.reduce_1x1convs = nn.ModuleList([
            nn.Conv2d(block1_channels, inner_dim, 1), 
            nn.Conv2d(block2_channels, inner_dim, 1), 
            nn.Conv2d(block3_channels, inner_dim, 1), 
            nn.Conv2d(block4_channels, inner_dim, 1),
        ])
        self.fuse_3x3convs = nn.ModuleList([
            conv3x3_gn_relu(inner_dim, inner_dim, r), 
            conv3x3_gn_relu(inner_dim, inner_dim, r), 
            conv3x3_gn_relu(inner_dim, inner_dim, r), 
            nn.Conv2d(inner_dim, self.config.in_channels, 3, 1, 1), 
        ])
        
        self.cls_pred_conv = nn.Conv2d(self.config.in_channels, self.config.num_classes, 1)
        
    def top_down(self, top, lateral):
        top2x = F.interpolate(top, scale_factor=2.0, mode='bilinear')
        return top2x + lateral
    
    def forward(self, x, y=None, train_inds=None, **kwargs):
        feat_list = []
        for op in self.feature_ops:
            x = op(x)
            
            if isinstance(op, nn.Identity):
                feat_list.append(x)
        inner_feat_list = [self.reduce_1x1convs[i](feat) for i, feat in enumerate(feat_list)]
        
        inner_feat_list.reverse()  # [(batch_size, 128, 78, 44), (batch_size, 128, 156, 88), ...]
        out_feat_list = [self.fuse_3x3convs[0](inner_feat_list[0])]  # (batch_size, 128, 78, 44)
        for i in range(len(inner_feat_list) - 1):
            inner = slef.top_down(out_feat_list[i], inner_feat_list[i + 1])
            out = self.fuse_3x3convs[i + 1](inner)
            out_feat_list.append(out)
        final_feat = out_feat_list[-1]  # (batch_size, 103, 624, 352) This is the final feature space!!!
        
        logit = self.cls_pred_conv(final_feat)  # (batch_size, 9, 624, 352)
        if self.training:
            loss_dict = {'cls_loss': self.loss(logit, y, train_inds, final_feat)}
            return loss_dict
        
        return torch.softmax(logit, dim=1)  # (batch_size, 9, 624, 352)
    
    def loss(self, x, y, train_inds, final_feat):
        beta = 0.9999
        if dataset_path =="SSDGL.SSDGL_1_0_pavia":
            cls_num_list = [6631, 18649, 2099, 3064, 1345, 5029, 1330, 3682, 947]
        elif dataset_path =="SSDGL.SSDGL_1_0_Indianpine":
            cls_num_list = [46, 1428, 830, 237, 483, 730, 28, 478, 20, 972, 2455, 593, 205, 1265, 386, 93]
        elif dataset_path == "SSDGL.SSDGL_1_0_salinas":
            cls_num_list = [2009, 3726, 1976, 1394, 2678, 3959, 3579, 11271, 6203, 3278, 1068,1927, 916, 1070, 7268, 1807]
        elif dataset_path == "SSDGL.SSDGL_1_0_HOS":
            cls_num_list = [1251, 1254, 697, 1244, 1242, 325,1268, 1244, 1252, 1227, 1235, 1233, 469, 428, 660]
        else:
            print("no cls_num_list")
        effective_num = 1.0 - np.power(beta, cls_num_list)
        per_cls_weights = (1.0 - beta) / np.array(effective_num)
        per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
        # 有标签的样本中，数量越少的类别，它的损失权重越大
        per_cls_weights = torch.FloatTensor(per_cls_weights).cuda()
        
        # (1, 624, 352) 注意!!!这里的程序十分关键，要好好分析为什么这样计算损失?
        weighted_losses = F.cross_entropy(x, y.long() - 1, ignore_index=-1, reduction='none', 
                                          weight=per_cls_weights)
        #print(weighted_losses.size())
        weighted_cross_entropy_losses = weighted_losses.mul_(train_inds).sum() / train_inds.sum()
        #print(weighted_cross_entropy_losses, train_inds.sum())
        
        losses = weighted_cross_entropy_losses + self.statistical_loss(y, train_inds, final_feat)
        return losses
    
    def statistical_loss(self, y, train_inds, final_feat):  
        # y: (1, 624, 352), train_inds: (1, 624, 352), final_feat: (1, 103, 624, 352)
        lamb = 0.01
        delta = 0.
        y = y.squeeze()  # ground truth: (624, 352)
        train_inds = train_inds.squeeze()  # (624, 352)
        final_feat = final_feat.squeeze()  # (103, 624, 352)
        
        cls_list = torch.unique(y)  # (0, 1, 2, .., 9)
        num_cls = len(cls_list) - 1  # 9
        num_train = int(train_inds.sum())  # 90
        feat_dimension = final_feat.size()[0]  # 103
        
        location = torch.where(train_inds == 1.)
        label = y[location]
        
        feat_dict_per_class = dict()
        for i in range(1, num_cls+1):
            feat_inds = torch.where(label==i)
            feat_dict_per_class[i] = final_feat[:, location[0][feat_inds], location[1][feat_inds]]
            
        ck = dict()
        for i in range(1, num_cls+1):
            ck[i] = feat_dict_per_class[i].mean(dim=1).unsqueeze(dim=1)
            
        variance_loss = torch.tensor(0.).cuda()
        for i in range(1, num_cls+1):
            zj_ck = feat_dict_per_class[i] - ck[i]  # (103, num_train)
            variance_loss += zj_ck.mul(zj_ck).sum() / (zj_ck.size()[1] - 1)
        variance_loss = variance_loss / num_cls
        
        diver_loss = torch.tensor(0.).cuda()
        for k in range(1, num_cls+1):
            Sk = torch.zeros(feat_dimension, feat_dimension).cuda()
            zj_ck = feat_dict_per_class[k] - ck[k]
            nk = zj_ck.size()[1]
            for i in range(nk):
                Sk += torch.mm(zj_ck[:, i].unsqueeze(dim=1), zj_ck[:, i].unsqueeze(dim=0))
            for t in range(k+1, num_cls+1):
                St = torch.zeros(feat_dimension, feat_dimension).cuda()
                zj_ct = feat_dict_per_class[t] - ck[t]
                nt = zj_ct.size()[1]
                for j in range(nt):
                    St += torch.mm(zj_ct[:, j].unsqueeze(dim=1), zj_ct[:, j].unsqueeze(dim=0))
                ck_ct = ck[k] - ck[t]
                diver_loss += delta - torch.mm(torch.mm(ck_ct.transpose(1, 0), torch.inverse(Sk + St)), 
                                               ck_ct).squeeze() * (nk*nt - 2*nk*nt/(nk+nt))
        diver_loss = diver_loss * lamb
                
        return variance_loss + diver_loss
    
    def set_defalut_config(self):
        # pavia
        self.config.update(dict(
            in_channels=103, 
            num_classes=9, 
            block_channels=(96, 128, 192, 256), 
            num_blocks=(1, 1, 1, 1), 
            inner_dim=128, 
            reduction_ratio=1.0, 
        ))

### 5. 分步测试

In [51]:
train_inds_list = []
for idx, (im, mask, w) in enumerate(traindata_loader):
    train_inds_list.append(w)
    
train_inds = train_inds_list[0]

In [52]:
y = mask.squeeze()
train_inds = train_inds.squeeze()

In [53]:
final_feat = torch.randn(103, 624, 352).cuda()

In [54]:
cls_list = torch.unique(y)

In [55]:
cls_list

tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])

In [56]:
num_cls = len(cls_list) - 1
num_train = int(train_inds.sum())

In [57]:
num_train

90

In [58]:
num_cls

9

In [59]:
location = torch.where(train_inds == 1.)
label = y[location]

In [60]:
location

(tensor([ 28,  50,  68, 102, 105, 107, 146, 149, 151, 154, 162, 164, 168, 173,
         183, 185, 194, 194, 200, 202, 202, 204, 213, 216, 223, 227, 232, 240,
         247, 254, 255, 274, 280, 295, 304, 307, 307, 307, 315, 315, 318, 325,
         325, 328, 334, 335, 337, 337, 338, 338, 345, 347, 350, 351, 357, 358,
         360, 366, 366, 372, 373, 395, 406, 406, 429, 433, 438, 442, 442, 451,
         454, 463, 468, 492, 518, 523, 529, 534, 535, 542, 551, 553, 564, 565,
         566, 569, 588, 598, 601, 604]),
 tensor([ 78, 122, 149, 155,  66,  28, 191, 128, 177,  50,  51, 115, 134, 132,
         121, 149, 138, 197,  14, 129, 140, 155,  38,  75, 148, 152,  21, 177,
          39, 163,  34, 204, 202, 190, 205,   2,  15, 187, 121, 208, 200, 110,
         132,  14,  79, 134, 123, 230,  17,  80, 169, 142, 164, 139, 144, 154,
         154,  31, 143,  38, 156,  72,  61, 295,  74,  90, 139,  10, 142, 191,
         184,  86,  14,  62,  88, 203, 164, 145, 141,   9, 238, 156, 172,  68,
          6

In [61]:
label

tensor([1., 1., 4., 1., 2., 2., 4., 5., 1., 8., 8., 9., 5., 5., 9., 5., 5., 1.,
        1., 9., 5., 5., 8., 8., 5., 5., 4., 5., 1., 9., 4., 6., 6., 6., 6., 3.,
        3., 6., 7., 6., 6., 7., 7., 3., 4., 9., 7., 6., 3., 4., 6., 7., 6., 7.,
        7., 7., 7., 3., 7., 3., 9., 3., 3., 1., 3., 9., 8., 9., 8., 9., 9., 8.,
        3., 2., 8., 2., 1., 4., 4., 1., 2., 4., 2., 8., 8., 4., 2., 2., 2., 2.])

In [62]:
feat_dict_per_class = dict()
train_per_class = []
for i in range(1, num_cls+1):
    feat_inds = torch.where(label==i)[0]
    train_per_class.append(len(feat_inds))
    feat_dict_per_class[i] = final_feat[:, location[0][feat_inds], location[1][feat_inds]]

In [63]:
feat = feat_dict_per_class[1]
for i in range(1, num_cls):
    feat = torch.cat([feat, feat_dict_per_class[i+1]], dim=1)

In [64]:
feat.size()

torch.Size([103, 90])

In [65]:
distance = torch.zeros((num_train, num_train), device=torch.device("cuda:0"))

In [66]:
for i in range(num_train):
    anchor = torch.unsqueeze(feat[:, i], 1)
    distance[i, :] = torch.norm(anchor - feat, p=2, dim=0)

In [67]:
anchor = torch.unsqueeze(feat[:, 0], 1)
c = torch.norm(anchor - feat, p=2, dim=0)

In [68]:
c

tensor([ 0.0000, 13.8653, 14.2857, 13.6392, 14.5955, 15.6235, 15.7290, 14.3485,
        14.4833, 15.2859, 14.5438, 13.6323, 13.3772, 15.5281, 15.3592, 13.0736,
        14.3455, 14.9672, 12.5538, 12.8872, 15.4634, 13.8352, 13.9301, 13.2862,
        14.8436, 13.3855, 13.9725, 14.1641, 14.0739, 13.2864, 15.9125, 13.7636,
        13.3306, 14.1956, 13.8764, 14.0034, 12.8655, 14.0232, 14.8358, 13.2024,
        13.5184, 14.5184, 13.2835, 15.0811, 14.1817, 13.9105, 13.1548, 13.8443,
        14.3088, 15.2761, 14.2746, 15.4775, 14.1392, 13.7332, 14.5262, 14.4344,
        13.4065, 14.0440, 14.1902, 13.2463, 15.0064, 13.9796, 14.7549, 14.2155,
        13.2031, 14.0104, 13.1387, 14.3166, 11.7791, 12.5361, 15.1210, 16.0219,
        13.4867, 13.9428, 13.9933, 14.0897, 14.0433, 12.7283, 15.2156, 14.1142,
        13.7853, 13.9876, 12.7503, 15.1429, 13.8540, 15.8069, 14.2308, 13.9807,
        13.1221, 13.6015], device='cuda:0')

In [69]:
anchor = feat[:, 0]

In [70]:
torch.dist(anchor, feat[:, 89], p=2)

tensor(13.6015, device='cuda:0')

In [71]:
distance[45, 50]

tensor(14.7261, device='cuda:0')

In [72]:
distance[50, 45]

tensor(14.7261, device='cuda:0')

In [73]:
train_per_class = [5, 10, 10, 10, 10, 10, 5, 10, 5, 10, 10, 10, 10, 10, 10, 5]
num_cls = 16

In [74]:
acc_index = np.zeros(num_cls).astype(np.int16)
for i in range(num_cls):
    acc_index[i] = sum(train_per_class[:i+1])

In [75]:
acc_index

array([  5,  15,  25,  35,  45,  55,  60,  70,  75,  85,  95, 105, 115,
       125, 135, 140], dtype=int16)

In [76]:
import bisect

bisect_fn = bisect.bisect
position = bisect_fn(acc_index, 5)

In [77]:
position

1

In [78]:
if position == 0:
    left_index = 0
else:
    left_index = acc_index[position-1]
right_index = acc_index[position]
left_index, right_index

(5, 15)

In [79]:
distance[0, :left_index].size(), distance[0, right_index:].size()

(torch.Size([5]), torch.Size([75]))

In [80]:
alpha = 1.0
result = torch.zeros(num_train).cuda()
for i in range(90):
    position = bisect_fn(acc_index, i)
    if position == 0:
        left_index = 0
    else:
        left_index = acc_index[position-1]
    right_index = acc_index[position]
    #print(left_index, right_index)
    intra_dist = distance[i, left_index:right_index]
    inter_dist = torch.cat([distance[i, :left_index], distance[i, right_index:]])
    #print(intra_dist.size(), inter_dist.size())
    #print(torch.topk(intra_dist, k=4)[0], torch.topk(inter_dist, k=4, largest=False)[0])
    positive = (torch.topk(intra_dist, k=4)[0]).sum() / 4
    negative = (torch.topk(inter_dist, k=4, largest=False)[0]).sum() / 4
    result[i] = alpha + positive - negative
    if result[i] < 2.0:
        result[i] = 0.0

In [81]:
result

tensor([2.6971, 2.5929, 2.0472, 0.0000, 3.5177, 3.5436, 3.9423, 3.9382, 3.1814,
        3.3953, 3.6421, 3.6516, 3.7077, 3.9328, 3.1526, 3.2969, 3.6764, 2.7728,
        3.4938, 2.6953, 4.1464, 3.3536, 3.4477, 4.1294, 3.4954, 3.3877, 3.5933,
        3.4263, 3.2547, 3.3897, 3.0738, 3.8040, 3.0195, 3.1208, 3.2751, 2.9413,
        3.7101, 3.6262, 2.9235, 3.3959, 3.6894, 3.7830, 3.4623, 3.4566, 2.9977,
        3.0198, 3.0355, 3.2235, 3.6272, 3.5818, 3.2053, 3.4514, 3.3535, 3.1393,
        2.5718, 2.0719, 3.0380, 2.5163, 2.8320, 2.1598, 3.0866, 3.3315, 3.3731,
        2.9848, 3.4181, 2.9928, 3.1267, 3.0138, 3.6353, 3.3168, 2.9110, 3.1164,
        3.5157, 2.5362, 3.3207, 3.3088, 3.2438, 3.1809, 3.1998, 3.4505, 3.8424,
        3.3883, 3.7903, 2.9894, 3.0427, 2.0168, 2.9740, 3.4137, 2.5979, 2.4448],
       device='cuda:0')

In [82]:
result

tensor([2.6971, 2.5929, 2.0472, 0.0000, 3.5177, 3.5436, 3.9423, 3.9382, 3.1814,
        3.3953, 3.6421, 3.6516, 3.7077, 3.9328, 3.1526, 3.2969, 3.6764, 2.7728,
        3.4938, 2.6953, 4.1464, 3.3536, 3.4477, 4.1294, 3.4954, 3.3877, 3.5933,
        3.4263, 3.2547, 3.3897, 3.0738, 3.8040, 3.0195, 3.1208, 3.2751, 2.9413,
        3.7101, 3.6262, 2.9235, 3.3959, 3.6894, 3.7830, 3.4623, 3.4566, 2.9977,
        3.0198, 3.0355, 3.2235, 3.6272, 3.5818, 3.2053, 3.4514, 3.3535, 3.1393,
        2.5718, 2.0719, 3.0380, 2.5163, 2.8320, 2.1598, 3.0866, 3.3315, 3.3731,
        2.9848, 3.4181, 2.9928, 3.1267, 3.0138, 3.6353, 3.3168, 2.9110, 3.1164,
        3.5157, 2.5362, 3.3207, 3.3088, 3.2438, 3.1809, 3.1998, 3.4505, 3.8424,
        3.3883, 3.7903, 2.9894, 3.0427, 2.0168, 2.9740, 3.4137, 2.5979, 2.4448],
       device='cuda:0')

In [83]:
result.sum()

tensor(288.1718, device='cuda:0')

In [84]:
train_inds_list = []
for idx, (im, mask, w) in enumerate(traindata_loader):
    train_inds_list.append(w)
    
train_inds = train_inds_list[0]

y = mask.squeeze()
train_inds = train_inds.squeeze()

final_feat = torch.randn(103, 624, 352).cuda()

cls_list = torch.unique(y)

print(cls_list)

tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])


In [85]:
num_cls = len(cls_list) - 1
num_train = int(train_inds.sum())

In [86]:
num_cls, num_train

(9, 90)

In [87]:
location = torch.where(train_inds == 1.)
label = y[location]

In [88]:
feat_dict_per_class = dict()
train_per_class = []
for i in range(1, num_cls+1):
    feat_inds = torch.where(label==i)[0]
    train_per_class.append(len(feat_inds))
    feat_dict_per_class[i] = final_feat[:, location[0][feat_inds], location[1][feat_inds]]

In [89]:
feat = feat_dict_per_class[1]
for i in range(1, num_cls):
    feat = torch.cat([feat, feat_dict_per_class[i+1]], dim=1)

In [90]:
distance = torch.zeros((num_train, num_train), device=torch.device("cuda:0"))

for i in range(num_train):
    anchor = torch.unsqueeze(feat[:, i], 1)
    distance[i, :] = torch.norm(anchor - feat, p=2, dim=0)

In [91]:
ck = dict()
for i in range(1, num_cls+1):
    ck[i] = feat_dict_per_class[i].mean(dim=1).unsqueeze(dim=1)

In [92]:
centers = ck[1]
for i in range(1, num_cls):
    centers = torch.cat([centers, ck[i+1]], dim=1)

In [93]:
radius = torch.norm(centers, p=2, dim=0).mean()
radius

tensor(3.2014, device='cuda:0')

In [94]:
centers[:, 1]

tensor([-0.1706,  0.0887, -0.5220, -0.0099, -0.2277, -0.1112, -0.3236,  0.0628,
        -0.1747,  0.1720,  0.4116, -0.0681,  0.2853, -0.1676,  0.0126,  0.4607,
        -0.2436,  0.1084,  0.4405, -0.2513,  0.0321, -0.2306, -0.3630, -0.2676,
         0.1296, -0.2641, -0.1007,  0.5824,  0.6994, -0.0352, -0.5557,  0.7687,
         0.3566, -0.2765,  0.0725,  0.6493,  0.2286, -0.2276, -0.0267,  0.2053,
         0.1496,  0.4600, -0.4200,  0.2003,  0.1398,  0.0827, -0.1850, -0.2661,
         0.2280, -0.1412, -0.0511, -0.0246, -0.2369,  0.4583, -0.0242, -0.3626,
         0.0987, -0.7115, -0.2526,  0.0406,  0.0855,  0.1035, -0.3285,  0.1473,
         0.2018, -0.3990,  0.3077,  0.5156, -0.5354,  0.0314, -0.3164, -0.1999,
        -0.1038, -0.5451, -0.1978, -0.1962,  0.5864,  0.3214,  0.2991, -0.1843,
        -0.0469, -0.2046, -0.0845, -0.2531,  0.1264,  0.6228,  0.2213,  0.2646,
        -0.2768, -0.0662,  0.1092,  0.5880, -0.2813, -0.0418,  0.3219,  0.1254,
         0.1677, -0.2811, -0.1233,  0.13

In [95]:
torch.norm(centers, p=2, dim=0)

tensor([3.7301, 3.0787, 3.3628, 2.9282, 2.8457, 3.4343, 3.4883, 2.9951, 2.9491],
       device='cuda:0')

In [96]:
centers_trans = centers / torch.norm(centers, p=2, dim=0)

In [97]:
centers_trans[:, 1]

tensor([-0.0554,  0.0288, -0.1696, -0.0032, -0.0740, -0.0361, -0.1051,  0.0204,
        -0.0567,  0.0559,  0.1337, -0.0221,  0.0927, -0.0544,  0.0041,  0.1496,
        -0.0791,  0.0352,  0.1431, -0.0816,  0.0104, -0.0749, -0.1179, -0.0869,
         0.0421, -0.0858, -0.0327,  0.1892,  0.2272, -0.0114, -0.1805,  0.2497,
         0.1158, -0.0898,  0.0235,  0.2109,  0.0742, -0.0739, -0.0087,  0.0667,
         0.0486,  0.1494, -0.1364,  0.0651,  0.0454,  0.0268, -0.0601, -0.0864,
         0.0741, -0.0459, -0.0166, -0.0080, -0.0770,  0.1489, -0.0079, -0.1178,
         0.0321, -0.2311, -0.0820,  0.0132,  0.0278,  0.0336, -0.1067,  0.0478,
         0.0656, -0.1296,  0.0999,  0.1675, -0.1739,  0.0102, -0.1028, -0.0649,
        -0.0337, -0.1771, -0.0643, -0.0637,  0.1905,  0.1044,  0.0972, -0.0599,
        -0.0152, -0.0664, -0.0274, -0.0822,  0.0411,  0.2023,  0.0719,  0.0859,
        -0.0899, -0.0215,  0.0355,  0.1910, -0.0914, -0.0136,  0.1045,  0.0407,
         0.0545, -0.0913, -0.0400,  0.04

In [98]:
centers_trans = centers_trans * radius

In [99]:
centers_trans[:, 1] / radius

tensor([-0.0554,  0.0288, -0.1696, -0.0032, -0.0740, -0.0361, -0.1051,  0.0204,
        -0.0567,  0.0559,  0.1337, -0.0221,  0.0927, -0.0544,  0.0041,  0.1496,
        -0.0791,  0.0352,  0.1431, -0.0816,  0.0104, -0.0749, -0.1179, -0.0869,
         0.0421, -0.0858, -0.0327,  0.1892,  0.2272, -0.0114, -0.1805,  0.2497,
         0.1158, -0.0898,  0.0235,  0.2109,  0.0742, -0.0739, -0.0087,  0.0667,
         0.0486,  0.1494, -0.1364,  0.0651,  0.0454,  0.0268, -0.0601, -0.0864,
         0.0741, -0.0459, -0.0166, -0.0080, -0.0770,  0.1489, -0.0079, -0.1178,
         0.0321, -0.2311, -0.0820,  0.0132,  0.0278,  0.0336, -0.1067,  0.0478,
         0.0656, -0.1296,  0.0999,  0.1675, -0.1739,  0.0102, -0.1028, -0.0649,
        -0.0337, -0.1771, -0.0643, -0.0637,  0.1905,  0.1044,  0.0972, -0.0599,
        -0.0152, -0.0664, -0.0274, -0.0822,  0.0411,  0.2023,  0.0719,  0.0859,
        -0.0899, -0.0215,  0.0355,  0.1910, -0.0914, -0.0136,  0.1045,  0.0407,
         0.0545, -0.0913, -0.0400,  0.04

In [100]:
acc_index = np.zeros(num_cls).astype(np.int16)
for i in range(num_cls):
    acc_index[i] = sum(train_per_class[:i+1])

In [101]:
alpha = 1.0
result = torch.zeros(num_train).cuda()
for i in range(90):
    position = bisect_fn(acc_index, i)
    if position == 0:
        left_index = 0
    else:
        left_index = acc_index[position-1]
    right_index = acc_index[position]
    #print(left_index, right_index)
    intra_dist = distance[i, left_index:right_index]
    inter_dist = torch.cat([distance[i, :left_index], distance[i, right_index:]])
    #print(intra_dist.size(), inter_dist.size())
    #print(torch.topk(intra_dist, k=4)[0], torch.topk(inter_dist, k=4, largest=False)[0])
    positive = (torch.topk(intra_dist, k=4)[0]).sum() / 4
    negative = (torch.topk(inter_dist, k=4, largest=False)[0]).sum() / 4
    result[i] = alpha + positive - negative
    if result[i] < 2.0:
        result[i] = 0.0

In [102]:
class_ind = bisect_fn(acc_index, 89)
class_ind

8

In [103]:
torch.min(torch.norm((feat[:, 0].unsqueeze(dim=1)-centers), p=2, dim=0), dim=0)[1]

tensor(0, device='cuda:0')

In [104]:
centers_trans[:, 0].size()

torch.Size([103])

In [105]:
num_train

90

In [106]:
result = torch.zeros(num_train).cuda()
for i in range(num_train):
    class_ind = bisect_fn(acc_index, i)
    #print(class_ind)
    if class_ind == 0:
        negative_centers = centers[:, 1:]
    else:
        negative_centers = torch.cat([centers[:, :class_ind], centers[:, class_ind+1:]], dim=1)
    #print(negative_centers.size())
    closest_negative_center_index = torch.min(torch.norm(feat[:, i].unsqueeze(dim=1) - negative_centers, 
                                                         p=2, dim=0), dim=0)[1]
    closest_negative_center = negative_centers[:, closest_negative_center_index]
    closest_negative_center_trans = closest_negative_center / torch.norm(closest_negative_center, p=2) * radius
    interclass_dists = torch.square(torch.dist(centers_trans[:, class_ind], closest_negative_center_trans, p=2))
    #print(centers_dists)
    intraclass_dists = torch.square(torch.dist(feat[:, i], centers_trans[:, class_ind], p=2))
    result[i] =  intraclass_dists - interclass_dists * 0.2 + radius **2 / 2
    if result[i] < 0.0:
        result[i] = 0.0
    #print(result[i])
print(result.sum() / (2 * num_train))

tensor(46.7359, device='cuda:0')


In [107]:
closest_negative_center_index = torch.min(torch.norm(feat[:, 0].unsqueeze(dim=1) - centers, 
                                                         p=2, dim=0), dim=0)[1]

In [108]:
closest_negative_center_index

tensor(0, device='cuda:0')

In [109]:
torch.square(torch.dist(centers_trans[:, 0], centers_trans[:, closest_negative_center_index], p=2))

tensor(0., device='cuda:0')

In [110]:
variance_loss = torch.tensor(0.).cuda()
for i in range(1, num_cls+1):
    zj_ck = feat_dict_per_class[i] - ck[i]  # (103, num_train)
    variance_loss += zj_ck.mul(zj_ck).sum() / (zj_ck.size()[1] - 1)
variance_loss = variance_loss / num_cls

In [111]:
variance_loss

tensor(102.3473, device='cuda:0')

In [112]:
feat_dimension = final_feat.size()[0]  # 103

In [113]:
feat_dimension

103

In [114]:
lamb = 1e-10
delta = 0.
diver_loss = torch.tensor(0.).cuda()
for k in range(1, num_cls+1):
    Sk = torch.zeros(feat_dimension, feat_dimension).cuda()
    zj_ck = feat_dict_per_class[k] - ck[k]
    nk = zj_ck.size()[1]
    for i in range(nk):
        Sk += torch.mm(zj_ck[:, i].unsqueeze(dim=1), zj_ck[:, i].unsqueeze(dim=0))
    for t in range(k+1, num_cls+1):
        #print(k, t)
        St = torch.zeros(feat_dimension, feat_dimension).cuda()
        zj_ct = feat_dict_per_class[t] - ck[t]
        nt = zj_ct.size()[1]
        for j in range(nt):
            St += torch.mm(zj_ct[:, j].unsqueeze(dim=1), zj_ct[:, j].unsqueeze(dim=0))
        
        ck_ct = ck[k] - ck[t]
        diver_loss += delta - torch.mm(torch.mm(ck_ct.transpose(1, 0), torch.inverse(Sk + St)), 
                                       ck_ct).squeeze() * (nk*nt - 2*nk*nt/(nk+nt))
diver_loss = diver_loss * lamb

In [115]:
diver_loss

tensor(0.5661, device='cuda:0')

In [116]:
diver_loss = torch.tensor(0.).cuda()
for k in range(1, num_cls+1):
    for t in range(k+1, num_cls+1):
        diver_loss -= torch.dist(ck[k], ck[t], p=2)
diver_loss * 1e-3

tensor(-0.1621, device='cuda:0')

In [117]:
diver_loss = torch.tensor(0.).cuda()
for k in range(1, num_cls+1):
    for t in range(k+1, num_cls+1):
        diver_loss -= torch.dist(ck[k], ck[t], p=1)
diver_loss * 1e-4

tensor(-0.1318, device='cuda:0')