<a href="https://colab.research.google.com/github/ucalyptus/BS-Nets-Implementation-Pytorch/blob/master/BSNets_with_Dual_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from scipy import io 
import torch.utils.data
import scipy
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import math
from sklearn.metrics import mean_squared_error

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [160]:
!pip install -U spectral
!pip install pytorch_ssim
from pytorch_ssim import ssim
if not (os.path.isfile('/content/Indian_pines_corrected.mat')):
  !wget http://www.ehu.eus/ccwintco/uploads/6/67/Indian_pines_corrected.mat
if not (os.path.isfile('/content/Indian_pines_gt.mat')):
  !wget http://www.ehu.eus/ccwintco/uploads/c/c4/Indian_pines_gt.mat

Requirement already up-to-date: spectral in /usr/local/lib/python3.6/dist-packages (0.20)


In [0]:
from torch.nn import Module, Sequential, Conv2d, ReLU,AdaptiveMaxPool2d, AdaptiveAvgPool2d, \
    NLLLoss, BCELoss, CrossEntropyLoss, AvgPool2d, MaxPool2d, Parameter, Linear, Sigmoid, Softmax, Dropout, Embedding
from torch.nn import functional as F

In [0]:
def padWithZeros(X, margin=2):

    ## From: https://github.com/gokriznastic/HybridSN/blob/master/Hybrid-Spectral-Net.ipynb
    newX = np.zeros((X.shape[0] + 2 * margin, X.shape[1] + 2* margin, X.shape[2]))
    x_offset = margin
    y_offset = margin
    newX[x_offset:X.shape[0] + x_offset, y_offset:X.shape[1] + y_offset, :] = X
    return newX

def createImageCubes(X, y, windowSize=5, removeZeroLabels = True):

     ## From: https://github.com/gokriznastic/HybridSN/blob/master/Hybrid-Spectral-Net.ipynb
    margin = int((windowSize - 1) / 2)
    zeroPaddedX = padWithZeros(X, margin=margin)
    # split patches
    patchesData = np.zeros((X.shape[0] * X.shape[1], windowSize, windowSize, X.shape[2]), dtype=np.uint8)
    patchesLabels = np.zeros((X.shape[0] * X.shape[1]), dtype=np.uint8)
    patchIndex = 0
    for r in range(margin, zeroPaddedX.shape[0] - margin):
        for c in range(margin, zeroPaddedX.shape[1] - margin):
            patch = zeroPaddedX[r - margin:r + margin + 1, c - margin:c + margin + 1]   
            patchesData[patchIndex, :, :, :] = patch
            patchesLabels[patchIndex] = y[r-margin, c-margin]
            patchIndex = patchIndex + 1
    if removeZeroLabels:
        patchesData = patchesData[patchesLabels>0,:,:,:]
        patchesLabels = patchesLabels[patchesLabels>0]
        patchesLabels -= 1
    return patchesData, patchesLabels


In [0]:
class HyperSpectralDataset(Dataset):
    """HyperSpectral dataset."""

    def __init__(self,data_url,label_url):
        
        self.data = np.array(scipy.io.loadmat('/content/'+data_url.split('/')[-1])[data_url.split('/')[-1].split('.')[0].lower()])
        self.targets = np.array(scipy.io.loadmat('/content/'+label_url.split('/')[-1])[label_url.split('/')[-1].split('.')[0].lower()])
        self.data, self.targets = createImageCubes(self.data,self.targets, windowSize=5)
        
        #self.data = self.data.reshape((-1, self.data.shape[1], self.data.shape[2], self.data.shape[3], 1))
        
        self.data = self.data[:10240,:,:,:]
        self.targets = self.targets[:10240]
        self.data = torch.Tensor(self.data)
        self.data = self.data.permute(0,3,1,2)

    def __len__(self):
        return self.data.shape[0]
    
    def __getitem__(self, idx):
      
      return self.data[idx,:,:,:] , self.targets[idx]


In [0]:
data_train = HyperSpectralDataset('Indian_pines_corrected.mat','Indian_pines_gt.mat')
train_loader = DataLoader(data_train, batch_size=64, shuffle=True)

In [165]:
"""class BAM(nn.Module):
    def __init__(self):
      
        super(BAM, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(200, 64, (3, 3), 1, 0),
                                   nn.ReLU(True))

        self.fc1 = nn.Sequential(nn.Linear(64,128),
                                 nn.ReLU(True))
    
        self.fc2 = nn.Sequential(nn.Linear(128,200),
                                 nn.Sigmoid())
                    
    def forward(self,x):
            
        x = self.conv1(x)
        print(x.shape)
        x = F.avg_pool2d(x, x.size()[2:4])
        print(x.shape)
        x = x.view(-1, 64)
        print(x.shape)
        x = self.fc1(x)
        
        x = self.fc2(x)
        print(x.shape)
        return x.unsqueeze(2).unsqueeze(3)     """

'class BAM(nn.Module):\n    def __init__(self):\n      \n        super(BAM, self).__init__()\n        self.conv1 = nn.Sequential(nn.Conv2d(200, 64, (3, 3), 1, 0),\n                                   nn.ReLU(True))\n\n        self.fc1 = nn.Sequential(nn.Linear(64,128),\n                                 nn.ReLU(True))\n    \n        self.fc2 = nn.Sequential(nn.Linear(128,200),\n                                 nn.Sigmoid())\n                    \n    def forward(self,x):\n            \n        x = self.conv1(x)\n        print(x.shape)\n        x = F.avg_pool2d(x, x.size()[2:4])\n        print(x.shape)\n        x = x.view(-1, 64)\n        print(x.shape)\n        x = self.fc1(x)\n        \n        x = self.fc2(x)\n        print(x.shape)\n        return x.unsqueeze(2).unsqueeze(3)     '

In [0]:
"""class BSNET_Conv(nn.Module):
  
    def __init__(self):
        super(BSNET_Conv, self).__init__()
        self.BAM = BAM()
        self.RecNet = RecNet()

    def forward(self,x):
        #print('before bam ', x.shape)
        BRW = self.BAM(x)
        x = x * BRW
        #print('after bam ',x.shape)
        
        x = x.unsqueeze(1)
        #print(x.shape)
        ret = self.RecNet(x)
        #print('after reconstruction', ret.shape)
        
        return ret"""
model = BSNET_Conv().to(device)  

In [0]:
class PAM_Module(Module):
    """ Position attention module  https://github.com/junfu1115/DANet/blob/master/encoding/nn/attention.py"""
    #Ref from SAGAN
    def __init__(self, in_dim):
        super(PAM_Module, self).__init__()
        self.chanel_in = in_dim

        self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        
        self.gamma = Parameter(torch.zeros(1))

        self.softmax = Softmax(dim=-1)
    def forward(self, x):
        """
            inputs :
                x : input feature maps( B X C X H X W)
            returns :
                out : attention value + input feature
                attention: B X (HxW) X (HxW)
        """
        m_batchsize, C, height, width = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)

        out = self.gamma*out + x
        #out = F.avg_pool2d(out, out.size()[2:4])
        
        return out


class CAM_Module(Module):
    """ Channel attention module https://github.com/junfu1115/DANet/blob/master/encoding/nn/attention.py"""
    def __init__(self):
        super(CAM_Module, self).__init__()
        #self.chanel_in = in_dim
        


        self.gamma = Parameter(torch.zeros(1))
        self.softmax  = Softmax(dim=-1)
    def forward(self,x):
        """
            inputs :
                x : input feature maps( B X C X H X W)
            returns :
                out : attention value + input feature
                attention: B X C X C
        """
        m_batchsize, C, height, width = x.size()
        proj_query = x.view(m_batchsize, C, -1)
        proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
        energy = torch.bmm(proj_query, proj_key)
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy
        attention = self.softmax(energy_new)
        proj_value = x.view(m_batchsize, C, -1)

        out = torch.bmm(attention, proj_value)
        out = out.view(m_batchsize, C, height, width)

        out = self.gamma*out + x
        #out = F.avg_pool2d(out, out.size()[2:4])
        
        
        return out


In [0]:
class RecNet(nn.Module):
    def __init__(self):
        super(RecNet, self).__init__()
        self.conv3d_1 = nn.Sequential(nn.Conv3d(1, 24, (24, 3, 3), 1),
                        nn.BatchNorm3d(24),
                        nn.PReLU())
        
        self.conv3d_2 = nn.Sequential(nn.Conv3d(24, 48, (24, 3, 3), 1),
                        nn.BatchNorm3d(48),
                        nn.PReLU())
                        
        
        self.pool3d = nn.MaxPool3d((18, 1, 1), (18, 1, 1))
        
        self.deconv3d_1 = nn.Sequential(nn.ConvTranspose3d(48, 24, (9, 3, 3), (22, 1, 1)),
                          nn.BatchNorm3d(24),
                          nn.PReLU())
        
        self.deconv3d_2 = nn.Sequential(nn.ConvTranspose3d(24, 1, (38, 3, 3), (1, 1, 1)),
                          nn.BatchNorm3d(1))

    def forward(self, x):
        x = self.conv3d_1(x)
        x = self.conv3d_2(x)
        
        x = self.pool3d(x)
        
        x = self.deconv3d_1(x)
        x = self.deconv3d_2(x)
        
        return x.squeeze(1)

In [0]:
class DANet(Module):
  def __init__(self):
    super(DANet,self).__init__()
    self.PAM_Module = PAM_Module(200)
    self.CAM_Module = CAM_Module()
    self.RecNet = RecNet()
  def forward(self,x):
    
    P = self.PAM_Module(x)
    C = self.CAM_Module(x)
    #B,Ch,H,W = P.size()
    J = P + C
    J =  J.unsqueeze(1)
    ret = self.RecNet(J)
    
    
    return ret
    
    
danet_model = DANet().to(device)  

In [0]:
def psnr(x_true, x_pred):
    n_samples = x_true.shape[0]
    n_bands = x_true.shape[1]
    PSNR = np.zeros(n_bands)
    MSE = np.zeros(n_bands)
    mask = np.ones(n_bands)
    for k in range(n_bands):
        x_true_k = x_true[:, k].reshape([-1])
        x_pred_k = x_pred[:, k].reshape([-1])
        MSE[k] = 1.0 / n_samples * mean_squared_error(x_true_k, x_pred_k, )
        MAX_k = np.max(x_true_k)
        if MAX_k != 0:
            PSNR[k] = 10 * math.log10(math.pow(MAX_k, 2) / MSE[k])
        else:
            mask[k] = 0

    psnr = PSNR.sum()/mask.sum()
    mse = MSE.mean()
    print('psnr', psnr)
    print('mse', mse)
    
    return psnr, mse

In [0]:
#model = BSNET_Conv().to(device) 

optimizer = optim.SGD(danet_model.parameters(), lr=0.002, momentum=0.9)

In [178]:
def train(epoch):
    danet_model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        
        output = danet_model(data)
#         print(output.shape, data.shape)
        
        loss = F.l1_loss(output,data)
        loss.backward()
        optimizer.step()
        
        if batch_idx % 50 == 0:
          
            x_true = data.detach().cpu().numpy()
            x_predict = output.detach().cpu().numpy()
            x_pred_centre = x_predict[:, :, 2, 2]
            x_true_centre = x_true[:, :, 2, 2]
            psnr(x_true_centre, x_pred_centre)
        
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
        
        #if batch_idx % 100 == 0:
          #print(output.detach().cpu().numpy().shape)
            

def test():
    with torch.no_grad():
        danet_model.eval()
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = danet_model(data)

            # sum up batch loss
            test_loss += F.mse_loss(output, target).item()
            # get the index of the max log-probability
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
              .format(test_loss, correct, len(test_loader.dataset),
                      100. * correct / len(test_loader.dataset)))

for epoch in range(1, 40 + 1):
    train(epoch)
    #test()

psnr 23.204136186580573
mse 308.616954870224
psnr 23.1119640668545
mse 314.63019690513613
psnr 23.23874639175122
mse 305.86190910339354
psnr 23.215297499212816
mse 302.82755693435666
psnr 23.255090482952646
mse 299.50838470458984
psnr 23.35677429961558
mse 297.2118961906433
psnr 23.43951715247009
mse 291.64262910842893
psnr 23.452602081013154
mse 287.484260597229
psnr 23.54886758553451
mse 286.53689643859866
psnr 23.42139574256693
mse 288.3711777591705
psnr 23.559123986951327
mse 281.70388145446776
psnr 23.628482217477984
mse 276.86778133392335
psnr 23.62374674574582
mse 276.4517353439331
psnr 23.627054519901602
mse 279.0852673530579
psnr 23.718352149394388
mse 272.15919989585876
psnr 23.746903157481448
mse 272.093681678772
psnr 23.77648067858097
mse 271.5836813926697
psnr 23.843455734209346
mse 267.62526463508607
psnr 23.779071489290217
mse 269.7297765350342
psnr 23.940472504980562
mse 261.8976565933228
psnr 23.926654022997678
mse 259.8370475101471
psnr 23.931572084582296
mse 259.2024

In [179]:
import torchsummary
torchsummary.summary(danet_model.to(device),(200,5,5))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1             [-1, 25, 5, 5]           5,025
            Conv2d-2             [-1, 25, 5, 5]           5,025
           Softmax-3               [-1, 25, 25]               0
            Conv2d-4            [-1, 200, 5, 5]          40,200
        PAM_Module-5            [-1, 200, 5, 5]               0
           Softmax-6             [-1, 200, 200]               0
        CAM_Module-7            [-1, 200, 5, 5]               0
            Conv3d-8        [-1, 24, 177, 3, 3]           5,208
       BatchNorm3d-9        [-1, 24, 177, 3, 3]              48
            PReLU-10        [-1, 24, 177, 3, 3]               1
           Conv3d-11        [-1, 48, 154, 1, 1]         248,880
      BatchNorm3d-12        [-1, 48, 154, 1, 1]              96
            PReLU-13        [-1, 48, 154, 1, 1]               1
        MaxPool3d-14          [-1, 48, 

In [0]:
"""import spectral
data_url , label_url = 'Indian_pines_corrected.mat' ,'Indian_pines_gt.mat'
X = np.array(scipy.io.loadmat('/content/'+data_url.split('/')[-1])[data_url.split('/')[-1].split('.')[0].lower()])
y = np.array(scipy.io.loadmat('/content/'+label_url.split('/')[-1])[label_url.split('/')[-1].split('.')[0].lower()])
view = spectral.imshow(X,(30,20,100), classes=y,figsize=(9,9))
view.set_display_mode('overlay')
view.class_alpha = 0.5"""

In [180]:

for param in danet_model.parameters():
    print(param.shape)

torch.Size([1])
torch.Size([25, 200, 1, 1])
torch.Size([25])
torch.Size([25, 200, 1, 1])
torch.Size([25])
torch.Size([200, 200, 1, 1])
torch.Size([200])
torch.Size([1])
torch.Size([24, 1, 24, 3, 3])
torch.Size([24])
torch.Size([24])
torch.Size([24])
torch.Size([1])
torch.Size([48, 24, 24, 3, 3])
torch.Size([48])
torch.Size([48])
torch.Size([48])
torch.Size([1])
torch.Size([48, 24, 9, 3, 3])
torch.Size([24])
torch.Size([24])
torch.Size([24])
torch.Size([1])
torch.Size([24, 1, 38, 3, 3])
torch.Size([1])
torch.Size([1])
torch.Size([1])


In [0]:
"""class PAM(nn.Module):
  def __init__(self):
    super(PAM,self).__init__()
    self.B = nn.Sequential(nn.Conv2d(200, 64, (3, 3), 1, 0),
                                   nn.ReLU(True))
    self.C = nn.Sequential(nn.Conv2d(200, 64, (3, 3), 1, 0),
                                   nn.ReLU(True))
    self.D = nn.Sequential(nn.Conv2d(200, 64, (3, 3), 1, 0),
                                   nn.ReLU(True))
    self.soft = nn.Softmax2d()
    self.alpha = nn.Variable(torch.ones(1, 1), requires_grad=True)
  
  def forward(self,x):
    A = x # C H W 
    b = self.B(x) # C H W 
    c = self.C(x) # C H W 
    d = self.D(x) # C H W 
    
    b = b.view(-1,b.size()[2]*b.size()[3]).T # N C 
    c = c.view(-1,c.size()[2]*c.size()[3]) #C N
    d = d.view(-1,d.size()[2]*d.size()[3]) #C N
    S = self.soft(torch.mm(b,c)) # N N
    sd = torch.mm(d,S) # C N 
    sd = sd.view(-1,x.size()[2],x.size()[3])
    E = torch.add(self.alpha*sd , A)
    
    return E
    
    
        
    """

In [0]:
"""class CAM(nn.Module):
  def __init__(self):
    super(CAM,self).__init__()
    self.soft = nn.Softmax2d()
    
  def forward(self,x):
    # x is C * H * W
    y = x.view(-1,x.size()[2]*x.size()[3]) #C * N
    X = self.soft(torch.mm(y,y.T))   # C * C
    Xy = torch.mm(X,y)     # C * N
    Xy = Xy.view(-1,x.size()[2],x.size()[3]) #C H W
    E = torch.add(Xy,x) # C H W
    
    
    
    return E
"""

In [0]:
"""import matplotlib.pyplot as plt
%matplotlib inline
X, y = createImageCubes(X, y, windowSize=15)
def plot(r):
  assert r<=10000
  fig, axes = plt.subplots(32, 32, figsize=(20, 20))
  itera = [*range(r)]
  for t,ax in zip(itera,axes.flatten()):
    ax.imshow(X[t,:,:,0])
    plt.subplots_adjust(wspace=.5, hspace=.5)
plot(1000)"""

In [0]:
def visualize_tile():
    with torch.no_grad():
        # Get a batch of training data
        data = next(iter(train_loader))[0].to(device)
      
        input_tensor = data.cpu().numpy()
        
        transformed_input_tensor = model.RecNet(data).cpu().numpy()
        


        # Plot the results side-by-side
        f, axarr = plt.subplots(1, 2,figsize=(10,10))
        axarr[0].imshow(input_tensor[0,0,:,:],cmap='gnuplot')
        axarr[0].set_title('Dataset Images')

        axarr[1].imshow(transformed_tensor[0,0,:,:],cmap='gnuplot')
        axarr[1].set_title('Transformed Images')

#visualize_tile()

#plt.ioff()
#plt.show()