<a href="https://colab.research.google.com/github/ucalyptus/HybridSN-Pytorch/blob/master/HybridSN-ColabVersion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from scipy import io 
import torch.utils.data
import scipy
from scipy.stats import entropy
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import math
from sklearn.metrics import mean_squared_error
from sklearn.decomposition import PCA

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
from torch.nn import Module, Sequential, Conv2d, ReLU,AdaptiveMaxPool2d, AdaptiveAvgPool2d, \
    NLLLoss, BCELoss, CrossEntropyLoss, AvgPool2d, MaxPool2d, Parameter, Linear, Sigmoid, Softmax, Dropout, Embedding
from torch.nn import functional as F

In [None]:
# Install the `spectral` package. NOTE(review): no cell visible in this
# notebook imports `spectral`; confirm it is still needed.
!pip install -U spectral


# Download the Indian Pines corrected cube and its ground-truth map unless they
# already exist under /content. NOTE(review): wget saves into the current
# working directory, which is assumed to be /content (the Colab default) —
# later cells read the files both from '/content/...' and by relative path.
if not (os.path.isfile('/content/Indian_pines_corrected.mat')):
  !wget http://www.ehu.eus/ccwintco/uploads/6/67/Indian_pines_corrected.mat
if not (os.path.isfile('/content/Indian_pines_gt.mat')):
  !wget http://www.ehu.eus/ccwintco/uploads/c/c4/Indian_pines_gt.mat

Collecting spectral
  Downloading https://files.pythonhosted.org/packages/f5/ff/f6e238a941ed55079526996fee315fbee5167aaa64de3e64980637ac8f38/spectral-0.21-py3-none-any.whl (187kB)
     |█████████████████████           | 122kB 2.8MB/s eta 0

In [None]:
import scipy.io as sio
def loadData():
    """Load the Indian Pines cube and its ground-truth map from the working directory.

    Reads ``Indian_pines_corrected.mat`` and ``Indian_pines_gt.mat`` (relative
    paths) with ``scipy.io.loadmat`` and pulls out the arrays stored under
    ``'indian_pines_corrected'`` and ``'indian_pines_gt'``.

    Returns:
        (data, labels): the two arrays exactly as loaded from the files.
    """
    cube_file, gt_file = 'Indian_pines_corrected.mat', 'Indian_pines_gt.mat'
    cube = sio.loadmat(cube_file)['indian_pines_corrected']
    ground_truth = sio.loadmat(gt_file)['indian_pines_gt']
    return cube, ground_truth

In [None]:
def padWithZeros(X, margin=2):
    """Return a float64 copy of ``X`` surrounded by a zero border on its first two axes.

    The result has shape ``(X.shape[0] + 2*margin, X.shape[1] + 2*margin,
    X.shape[2])``; the third axis is not padded.
    """
    ## From: https://github.com/gokriznastic/HybridSN/blob/master/Hybrid-Spectral-Net.ipynb
    height = X.shape[0]
    width = X.shape[1]
    channels = X.shape[2]
    border = 2 * margin
    padded = np.zeros((height + border, width + border, channels))
    # Copy the original values into the centre; everything else stays 0.
    padded[margin:height + margin, margin:width + margin, :] = X
    return padded

def createImageCubes(X, y, windowSize=9, removeZeroLabels = True):
    """Extract a ``windowSize`` x ``windowSize`` spatial patch centred on each pixel of ``X``.

    ``X`` is zero-padded with ``padWithZeros`` so border pixels also get
    full-size patches. Patches are produced in row-major pixel order.

    Args:
        X: array indexed as (rows, cols, bands).
        y: label map covering at least X's rows x cols grid; only its first
            ``X.shape[0]`` rows and ``X.shape[1]`` columns are used.
        windowSize: patch side length; must be a positive odd integer.
        removeZeroLabels: if True, pixels labelled 0 are dropped and the
            remaining labels are shifted down by one (so they start at 0).

    Returns:
        (patchesData, patchesLabels): ``patchesData`` has shape
        (n_patches, windowSize, windowSize, X.shape[2]) and the dtype of the
        padded cube returned by ``padWithZeros``; ``patchesLabels`` has y's dtype.

    Raises:
        ValueError: if ``windowSize`` is not a positive odd integer, or ``y``
            does not cover X's spatial grid.
    """
    ## Adapted from: https://github.com/gokriznastic/HybridSN/blob/master/Hybrid-Spectral-Net.ipynb
    if windowSize < 1 or windowSize % 2 == 0:
        # An even window cannot be centred on a pixel (and windowSize=2 used
        # to silently broadcast a single pixel into a 2x2 patch).
        raise ValueError(
            f"windowSize must be a positive odd integer, got {windowSize!r}")
    margin = (windowSize - 1) // 2
    zeroPaddedX = padWithZeros(X, margin=margin)
    height, width = X.shape[0], X.shape[1]
    labelMap = np.asarray(y)[:height, :width]
    if labelMap.shape[:2] != (height, width):
        raise ValueError(
            f"y must cover X's {height}x{width} spatial grid, got shape {np.shape(y)}")

    if removeZeroLabels:
        # Allocate patches only for labelled pixels rather than for every
        # pixel and filtering afterwards: the float patch array is large.
        rows, cols = np.nonzero(labelMap > 0)
    else:
        rows, cols = np.indices((height, width)).reshape(2, -1)

    # Keep the padded cube's floating dtype. A np.uint8 buffer would wrap and
    # truncate the signed, fractional whitened-PCA scores that applyPCA yields.
    patchesData = np.zeros(
        (len(rows), windowSize, windowSize, X.shape[2]), dtype=zeroPaddedX.dtype)
    for patchIndex, (r, c) in enumerate(zip(rows, cols)):
        # (r, c) are unpadded coordinates, so the window that starts there in
        # the padded cube is centred on pixel (r, c).
        patchesData[patchIndex] = zeroPaddedX[r:r + windowSize, c:c + windowSize]

    # Fancy indexing returns a copy, so the in-place shift never modifies y.
    patchesLabels = labelMap[rows, cols]
    if removeZeroLabels:
        patchesLabels -= 1
    return patchesData, patchesLabels


In [None]:

def applyPCA(X, numComponents=75):
    """Reduce the third axis of ``X`` to ``numComponents`` whitened PCA scores.

    Every (row, col) position is treated as one sample whose features are the
    values along the third axis. A ``PCA(whiten=True)`` model is fitted on all
    samples and the projected scores are reshaped back onto the spatial grid.

    Returns:
        (reduced, pca): ``reduced`` has shape (X.shape[0], X.shape[1],
        numComponents); ``pca`` is the fitted PCA object.
    """
    rows, cols, bands = X.shape[0], X.shape[1], X.shape[2]
    pixels = np.reshape(X, (-1, bands))
    pca = PCA(n_components=numComponents, whiten=True)
    scores = pca.fit_transform(pixels)
    return np.reshape(scores, (rows, cols, numComponents)), pca

In [None]:
from sklearn.model_selection import train_test_split

# The two .mat files downloaded earlier; each stores its array under the
# lower-cased file stem (e.g. 'indian_pines_corrected').
data_url, label_url = 'Indian_pines_corrected.mat', 'Indian_pines_gt.mat'
data_file, label_file = os.path.basename(data_url), os.path.basename(label_url)
X = np.array(scipy.io.loadmat('/content/' + data_file)[os.path.splitext(data_file)[0].lower()])
y = np.array(scipy.io.loadmat('/content/' + label_file)[os.path.splitext(label_file)[0].lower()])

X, _ = applyPCA(X, numComponents=30)
X, y = createImageCubes(X, y, windowSize=25)
# Add a singleton channel axis for Conv3d: (N, 1, 25, 25, 30). Cast to float32
# now (the model runs in float32 anyway) so the split below copies half the bytes.
X = np.expand_dims(X, axis=1).astype(np.float32)

# 30 % of the labelled pixels for training, 70 % for testing, stratified per class.
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.70, stratify=y)

# Labels are class indices, so store them as int64 as CrossEntropyLoss expects.
train_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(X_train).float(), torch.from_numpy(y_train).long())
# Reshuffle the training set every epoch.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# The test loader keeps its order so predictions line up with y_val.
test_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(X_val).float(), torch.from_numpy(y_val).long())
test_loader = DataLoader(test_dataset, batch_size=16)


In [None]:
class HybridSN(nn.Module):
    """HybridSN: three 3-D convolutions, one 2-D convolution and an MLP head.

    Expects input of shape (batch, 1, window_size, window_size, band). The
    3-D kernels' third dimension (3, 7, 5) slides over the last (band) axis.
    ``forward`` returns unnormalised logits of shape (batch, classes).

    Args:
        band: size of the input's last (spectral) axis; must be >= 13.
        classes: number of output classes.
        window_size: spatial side length of the input patch; must be >= 9.
            Defaults to 25, which together with band=30 reproduces the
            original hard-coded layer sizes (576 and 18496).

    Raises:
        ValueError: if ``band`` or ``window_size`` is too small for the
            unpadded convolutions to leave at least one output element.
    """

    def __init__(self, band, classes, window_size=25):
        super(HybridSN, self).__init__()
        # Unpadded convolutions shrink each axis by (kernel - 1):
        # spectral 2 + 6 + 4 = 12, spatial 2 per conv over 3 conv3d + 1 conv2d = 8.
        spectral_out = band - 12
        spatial_out = window_size - 8
        if spectral_out < 1 or spatial_out < 1:
            raise ValueError(
                f"need band >= 13 and window_size >= 9, got band={band}, "
                f"window_size={window_size}")

        self.conv3d_1 = nn.Sequential(nn.Conv3d(1, 8, (3, 3, 3)),
                                      nn.ReLU())
        self.conv3d_2 = nn.Sequential(nn.Conv3d(8, 16, (3, 3, 7)),
                                      nn.ReLU())
        self.conv3d_3 = nn.Sequential(nn.Conv3d(16, 32, (3, 3, 5)),
                                      nn.ReLU())
        # The 32 filter maps x remaining spectral positions become 2-D channels.
        self.conv2d_1 = nn.Sequential(nn.Conv2d(32 * spectral_out, 64, (3, 3)),
                                      nn.ReLU())

        self.dense1 = nn.Linear(64 * spatial_out * spatial_out, 256)
        self.dense2 = nn.Linear(256, 128)
        self.full = nn.Linear(128, classes)
        self.drop = nn.Dropout(p=0.4)
        # Kept as an attribute for existing references; forward() returns
        # logits and does NOT apply it (CrossEntropyLoss does its own softmax).
        self.soft = nn.Softmax(dim=-1)

    def forward(self, x):
        """Map a (batch, 1, H, W, band) tensor to (batch, classes) logits."""
        x = self.conv3d_1(x)
        x = self.conv3d_2(x)
        x = self.conv3d_3(x)

        batches, Q, H, W, C = x.size()
        # Fold the spectral axis into channels. Permute first so each of the
        # Q*C channels is a genuine H x W spatial map; a bare view() of the
        # (B, Q, H, W, C) memory would interleave spectral and spatial values.
        x = x.permute(0, 1, 4, 2, 3).reshape(batches, Q * C, H, W)

        x = self.conv2d_1(x)
        x = x.reshape(batches, -1)

        # ReLU between dense layers, otherwise the stack collapses to one
        # linear map (the cited reference uses ReLU on both hidden layers).
        x = self.drop(F.relu(self.dense1(x)))
        x = self.drop(F.relu(self.dense2(x)))
        # Raw logits: nn.CrossEntropyLoss applies log-softmax internally, and
        # argmax over logits equals argmax over probabilities.
        return self.full(x)
# Build the model for 30 PCA bands and 16 classes on the selected device.
net = HybridSN(30, 16).to(device)
from torchsummary import summary
# summary() prints the layer table itself; wrapping it in print() only
# appended its None return value to the output.
summary(net, (1, 25, 25, 30), batch_size=16)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1        [16, 8, 23, 23, 28]             224
              ReLU-2        [16, 8, 23, 23, 28]               0
            Conv3d-3       [16, 16, 21, 21, 22]           8,080
              ReLU-4       [16, 16, 21, 21, 22]               0
            Conv3d-5       [16, 32, 19, 19, 18]          23,072
              ReLU-6       [16, 32, 19, 19, 18]               0
            Conv2d-7           [16, 64, 17, 17]         331,840
              ReLU-8           [16, 64, 17, 17]               0
            Linear-9                  [16, 256]       4,735,232
          Dropout-10                  [16, 256]               0
           Linear-11                  [16, 128]          32,896
          Dropout-12                  [16, 128]               0
           Linear-13                   [16, 16]           2,064
          Softmax-14                   

In [None]:
# Cross-entropy criterion used by the training loop, and Adam (lr = 1e-3)
# over every parameter of the model.
ce_loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net.parameters(), lr=1e-3)

In [None]:
# Number of samples in the train and test splits, shown as the cell's output.
(len(train_loader.dataset), len(test_loader.dataset))

(3074, 7175)

In [None]:
def train(epoch):
    """Run one training epoch of the notebook-level ``net`` over ``train_loader``.

    Uses the module-level ``net``, ``optimizer``, ``ce_loss``, ``train_loader``
    and ``device``. Prints the current batch loss every 100 batches.

    Args:
        epoch: epoch number; only used in the progress message.
    """
    net.train()  # enable dropout
    for batch_idx, (data, targets) in enumerate(train_loader):
        # Move the targets as well as the inputs: the loss needs both on the
        # model's device, so leaving targets on the CPU fails under CUDA.
        data = data.to(device)
        targets = targets.to(device).long()
        optimizer.zero_grad()
        output = net(data)
        loss = ce_loss(output, targets)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


# 100 training epochs, numbered 0..99.
for epoch in range(100):
    train(epoch)
    



KeyboardInterrupt: ignored

In [None]:
def sampling(proportion, ground_truth):
    """Randomly split labelled pixel positions into train/test index lists, per class.

    Positions are indices into ``ground_truth.ravel()``. Label 0 is treated as
    unlabelled and never sampled; classes 1..max(label) are visited in order.

    Args:
        proportion: fraction of each class that goes to the TEST list. For each
            class, ``max(int((1 - proportion) * n_class), 3)`` shuffled
            positions go to train and the rest to test (so classes with at most
            3 samples are entirely in train). ``proportion == 1`` puts every
            position in test.
        ground_truth: integer label array of any shape (e.g. a 2-D label map).

    Returns:
        (train_indexes, test_indexes): shuffled lists of flat positions
        (Python ints). Randomness comes from the global ``np.random`` state.
    """
    flat = np.asarray(ground_truth).ravel()
    if flat.size == 0:
        return [], []
    # np.max rather than builtin max(): the builtin iterates the first axis,
    # which raises on a 2-D label map.
    n_classes = int(np.max(flat))

    train_indexes = []
    test_indexes = []
    for cls in range(1, n_classes + 1):
        positions = np.flatnonzero(flat == cls).tolist()
        np.random.shuffle(positions)
        if proportion != 1:
            n_train = max(int((1 - proportion) * len(positions)), 3)
        else:
            n_train = 0
        train_indexes += positions[:n_train]
        test_indexes += positions[n_train:]

    np.random.shuffle(train_indexes)
    np.random.shuffle(test_indexes)
    return train_indexes, test_indexes


In [None]:
import collections

from sklearn import metrics

# Evaluate the trained model on the held-out split.
# NOTE: this cell was originally copied from another project and referenced
# names that do not exist in this notebook (validation_loader, gt,
# test_indices, VAL_SIZE, aa_and_each_accuracy, record, KAPPA/OA/AA/
# ELEMENT_ACC, ...). It is rewritten against this notebook's own variables.
net.eval()  # disable dropout once, before predicting
pred_test = []
with torch.no_grad():
    for inputs, _ in test_loader:
        scores = net(inputs.to(device))  # one forward pass per batch
        pred_test.extend(scores.argmax(dim=1).cpu().tolist())

print(collections.Counter(pred_test))

# test_loader iterates (X_val, y_val) in order, and createImageCubes already
# shifted labels to start at 0, so y_val lines up with pred_test directly.
gt_test = np.asarray(y_val).astype(np.int64)

# sklearn metrics take (y_true, y_pred); the order matters for the confusion
# matrix, whose rows must be the true classes for per-class accuracy below.
overall_acc = metrics.accuracy_score(gt_test, pred_test)
confusion_matrix = metrics.confusion_matrix(gt_test, pred_test)
with np.errstate(divide='ignore', invalid='ignore'):
    # Per-class accuracy: correct predictions / true samples of that class.
    each_acc = np.diag(confusion_matrix) / confusion_matrix.sum(axis=1)
average_acc = np.nanmean(each_acc)
kappa = metrics.cohen_kappa_score(gt_test, pred_test)

os.makedirs("./net", exist_ok=True)
torch.save(net.state_dict(), "./net/" + str(round(overall_acc, 3)) + '.pt')

print("--------" + type(net).__name__ + " Training Finished-----------")
print(f"OA: {overall_acc:.4f}  AA: {average_acc:.4f}  Kappa: {kappa:.4f}")
print("Per-class accuracy:", np.round(each_acc, 4))
