In [1]:
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from evaluation import compute_rmse, compute_sad
from utils import print_args, SparseLoss, NonZeroClipper, MinVolumn
from data_loader import set_loader
from model import Init_Weights, MUNet

import matplotlib.pyplot as plt
import scipy.io as sio
import numpy as np
import argparse
import random
import time
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Experiment configuration. parse_args(args=[]) deliberately ignores the
# notebook kernel's own argv, so every value below is its default unless edited.
parser = argparse.ArgumentParser()
parser.add_argument('--fix_random', action='store_true', help='fix randomness')
parser.add_argument('--seed', default=0, type=int)
# NOTE(review): --gpu_id is parsed but not consumed by any cell shown here;
# confirm whether CUDA_VISIBLE_DEVICES should be set from it.
parser.add_argument('--gpu_id', default='0,1,2', help='gpu id')
parser.add_argument('--batch_size', default=128, type=int, help='batch size')
parser.add_argument('--patch', default=1, type=int, help='input data size')
parser.add_argument('--learning_rate_en', default=3e-4, type=float, help='learning rate of encoder')
parser.add_argument('--learning_rate_de', default=1e-4, type=float, help='learning rate of decoder')
parser.add_argument('--weight_decay', default=1e-5, type=float, help='network parameter regularization')
parser.add_argument('--lamda', default=0, type=float, help='sparse regularization')
parser.add_argument('--reduction', default=2, type=int, help='squeeze reduction')
parser.add_argument('--delta', default=0, type=float, help='delta coefficient')
parser.add_argument('--gamma', default=0.8, type=float, help='learning rate decay')
parser.add_argument('--epoch', default=200, type=int, help='number of epoch')
parser.add_argument('--dataset', choices=['muffle','houston170'], default='muffle', help='dataset to use')
args = parser.parse_args(args=[])


def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: integer seed applied to every random number generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op without CUDA; covers every visible GPU
    # Deterministic kernels trade speed for run-to-run reproducibility.
    cudnn.deterministic = True
    cudnn.benchmark = False


# Seed here, before later cells build the data loaders and initialize the
# network weights, so that --fix_random actually makes those steps repeatable.
if args.fix_random:
    seed_everything(args.seed)

In [3]:
# Build the data loaders. set_loader returns a 10-tuple that is unpacked into
# the module-level names used by the later cells.
(
    train_loaders,
    test_loaders,
    label,
    M_init,
    M_true,
    num_classes,
    band,
    col,
    row,
    ldr_dim,
) = set_loader(args)

**************************************************
patch is : 1
mirror_image shape : [130,90,64]
mirror_label shape : [130,90,5]
**************************************************


In [4]:
# Instantiate the network from the dataset dimensions returned by set_loader
# and the configured squeeze reduction.
net = MUNet(
    band,
    num_classes,
    ldr_dim,
    args.reduction,
)

In [5]:
# Per-dataset index order stored in `position`.
# NOTE(review): presumably used later to align estimated components with the
# ground truth ordering; confirm against the evaluation cells.
DATASET_POSITIONS = {
    'muffle': [0, 2, 1, 3, 4],
    'houston170': [0, 1, 2, 3],
}
# Fail loudly instead of silently skipping weight init and leaving `position`
# undefined (or stale from a previous run of this notebook).
if args.dataset not in DATASET_POSITIONS:
    raise ValueError(f"Unsupported dataset: {args.dataset!r}")
position = np.array(DATASET_POSITIONS[args.dataset])
# Every supported dataset uses the same initialization, so apply it once.
Init_Weights(net, 'xavier', 1)

Init Network Weights
initialize network with xavier


In [6]:
# Replace the first decoder layer's weight with M_init (returned by set_loader)
# by editing a copy of the state dict and loading it back.
# NOTE(review): assumes M_init is a tensor whose shape matches
# 'decoder.0.weight'; load_state_dict raises otherwise.
net_dict = net.state_dict()
net_dict.update({'decoder.0.weight': M_init})
net.load_state_dict(net_dict)

# loss function and regularization terms
apply_nonegative = NonZeroClipper()
loss_func = nn.MSELoss()
criterionSparse = SparseLoss(args.lamda)
criterionVolumn = MinVolumn(band, num_classes, args.delta)

In [7]:
# Optimizer setting: decoder parameters form their own group with the decoder
# learning rate; every other parameter uses the encoder learning rate. Both
# groups share the configured weight decay.
LR_STEP_SIZE = 20  # StepLR: multiply LRs by args.gamma every LR_STEP_SIZE scheduler.step() calls

decoder_params = list(net.decoder.parameters())
ignored_params = {id(p) for p in decoder_params}  # set -> O(1) membership below
base_params = [p for p in net.parameters() if id(p) not in ignored_params]
optimizer = torch.optim.Adam(
    [
        {'params': base_params},
        {'params': decoder_params, 'lr': args.learning_rate_de},
    ],
    lr=args.learning_rate_en,
    weight_decay=args.weight_decay,
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=LR_STEP_SIZE, gamma=args.gamma)

In [8]:
# Run every training batch through the network in training mode.
# NOTE(review): this cell only does forward passes (no loss, backward() or
# optimizer.step()), so it looks like a shape sanity check rather than training.
# The `encode.shape` / `attention.shape` lines in its output are not printed by
# this cell; presumably they come from inside MUNet's forward.
net.train()  # the mode is loop-invariant, so set it once rather than per batch
for x, y in train_loaders:
    abu, output = net(x, y)


encode.shape: torch.Size([128, 5, 1, 1])
attention.shape: torch.Size([128, 5, 1, 1])
encode.shape: torch.Size([128, 5, 1, 1])
attention.shape: torch.Size([128, 5, 1, 1])
encode.shape: torch.Size([128, 5, 1, 1])
attention.shape: torch.Size([128, 5, 1, 1])
encode.shape: torch.Size([128, 5, 1, 1])
attention.shape: torch.Size([128, 5, 1, 1])
encode.shape: torch.Size([128, 5, 1, 1])
attention.shape: torch.Size([128, 5, 1, 1])
encode.shape: torch.Size([128, 5, 1, 1])
attention.shape: torch.Size([128, 5, 1, 1])
encode.shape: torch.Size([128, 5, 1, 1])
attention.shape: torch.Size([128, 5, 1, 1])
encode.shape: torch.Size([128, 5, 1, 1])
attention.shape: torch.Size([128, 5, 1, 1])
encode.shape: torch.Size([128, 5, 1, 1])
attention.shape: torch.Size([128, 5, 1, 1])
encode.shape: torch.Size([128, 5, 1, 1])
attention.shape: torch.Size([128, 5, 1, 1])
encode.shape: torch.Size([128, 5, 1, 1])
attention.shape: torch.Size([128, 5, 1, 1])
encode.shape: torch.Size([128, 5, 1, 1])
attention.shape: torch.S