In [1]:
%run include.ipynb
%run Net.ipynb
%run Data.ipynb
%run viewer.ipynb
%run Medical_Utility.ipynb

from torch.autograd import Variable
from sklearn.metrics import roc_auc_score
from scipy.ndimage import zoom

class CNN(object):
    """Training / visualisation wrapper around a ``Network_template`` model.

    Builds the network, loss, SGD optimizer and a ``ReduceLROnPlateau``
    scheduler from the ``general`` / ``arch`` configs and the global ``FLAGS``
    object. Exposes ``train`` / ``test`` entry points plus two NumPy helpers
    used by a sparse "self-attention" re-weighting experiment.
    """

    def __init__(self, general, arch):
        """Build the model, loss, optimizer and LR scheduler.

        Args:
            general: mapping with at least the keys ``learning_rate``,
                ``beta1``, ``beta2``, ``loss`` and ``reduction``.
            arch: sequence of layer specs; only ``arch[0]`` is parsed here
                (via ``Net.parse_layers``).
        """
        # NOTE(review): lr/beta1/beta2 are read (so missing keys still raise
        # KeyError) but only the commented-out Adam optimizer uses them; the
        # active SGD optimizer below hard-codes lr=0.001. Confirm intent.
        lr        = general["learning_rate"]
        beta1     = general["beta1"]
        beta2     = general["beta2"]
        loss_mode = general["loss"]
        reduction = general["reduction"]

        cudnn.benchmark = FLAGS.cudnn_benchmark
        gpu_num     = FLAGS.gpu_num
        self.device = torch.device("cuda:0" if torch.cuda.is_available()
                      and FLAGS.gpu_enable else "cpu")
        # torch is seeded from Python's `random` module, so runs are only
        # reproducible if `random` itself was seeded beforehand.
        torch.manual_seed(random.randint(1, 10000))

        self.input_dims1, layers1 = Net.parse_layers(arch[0])
        #self.input_dims2, layers2 = Net.parse_layers(arch[1])
        #self.net = Network_template_hook(gpu_num, layers1, layers2).to(self.device)
        self.net = Network_template(gpu_num, layers1).to(self.device)

#         print(self.net.main[0].weight)
#         print("=================================================================")

#         from torchvision import models
#         self.net = models.resnet50(pretrained=False).to(self.device)
#         Net.init_weights(self.net, "kaiming")

#         print(self.net.main[0].weight)
#         print("=================================================================")

        self.criterion = StandardLoss(loss_mode, reduction).to(self.device)
        #self.optimizer = optim.Adam(self.net.parameters(), lr=lr, betas=(beta1,beta2))
        #self.optimizer = optim.SGD(self.net.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-2)
        self.optimizer = optim.SGD(self.net.parameters(), lr=0.001)
        # Created with PyTorch's default settings; `train` steps it once per
        # epoch with the summed epoch loss.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer)

    def optimize_step(self, Dtrain, labels):
        """Run one forward/backward/update step and return the scalar loss.

        Args:
            Dtrain: input batch, already on ``self.device``.
            labels: target batch, already on ``self.device``.

        Returns:
            The loss for this batch as a Python float.
        """
        #self.net.zero_grad()
        self.optimizer.zero_grad()
        # StandardLoss receives [net, inputs, labels] and runs the forward
        # pass itself (see StandardLoss for the exact contract).
        loss = self.criterion([self.net, Dtrain, labels])
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def gradual_outer_product(self, vec, interval):
        """Return ``np.outer(vec, vec).dot(vec)`` without forming the outer product.

        Since ``(v v^T) v == v * (v . v)``, the result is computed in O(n) time
        and memory instead of building (chunks of) the n x n outer product.
        Values agree with the explicit product up to floating-point rounding.

        Args:
            vec: 1-D array-like of length n.
            interval: chunk size formerly used to bound the memory of the
                row-wise outer product. The closed form does not need it; the
                parameter is kept so existing callers keep working.

        Returns:
            1-D ndarray of length n (empty for empty input).
        """
        v = np.asarray(vec)
        return v * v.dot(v)

    def sparse_attention_batch_NCWHD(self, data, *, verbose=False):
        """Re-weight every non-background voxel of each sample in a batch.

        For each sample, the "background" value is that sample's minimum. All
        voxels that differ from it are gathered in C order into a vector ``f``
        and replaced by ``(f f^T) f`` (see ``gradual_outer_product``).
        Background voxels keep the minimum value. A constant sample (no
        foreground) is returned unchanged.

        Args:
            data: batch indexed as ``data[i]`` per sample; each sample is
                squeezed before processing. The method name suggests
                (N, 1, W, H, D) input -- TODO confirm with callers.
            verbose: if True, print the count and fraction of non-background
                voxels for each sample.

        Returns:
            float64 ndarray of shape ``(N, 1) + squeezed_sample_shape``.
        """
        samples = []
        for idx in range(data.shape[0]):
            item = np.squeeze(np.asarray(data[idx]))
            background = np.amin(item)
            foreground = item != background
            if verbose:
                num_nonzero = np.count_nonzero(foreground)
                print(num_nonzero)
                print(num_nonzero / item.size)
            # Boolean-mask indexing yields elements in C order, i.e. the same
            # order as np.nonzero over the mask.
            values = item[foreground].astype(np.float64)
            item_res = np.full(item.shape, background, dtype=np.float64)
            item_res[foreground] = self.gradual_outer_product(values, 5000)
            samples.append(item_res)
        return np.expand_dims(np.stack(samples, axis=0), axis=1)

    def train(self, data_params, branch_name="Undefined Here"):
        """Train ``self.net`` on ``FLAGS.dataset`` for ``data_params["epochs"]`` epochs.

        Appends progress to ``FLAGS.log_path`` every ``FLAGS.print_step``
        steps and saves ``<FLAGS.model_save>/net_step_<step>.pth`` every
        ``FLAGS.save_step`` steps. When ``FLAGS.continue_model`` is set, it
        resumes from ``net_step_<FLAGS.model_step>.pth``.

        Args:
            data_params: mapping of data/loader settings (epochs, batch_size,
                batch_workers, shuffle, drop_last, datasplit_scheme,
                test_split, xfold, fold_idx, random_seed).
            branch_name: label written to the log header.
        """
        epochs           = data_params["epochs"]
        batch_size       = data_params["batch_size"]
        batch_workers    = data_params["batch_workers"]
        shuffle          = data_params["shuffle"]
        drop_last        = data_params["drop_last"]
        datasplit_scheme = data_params["datasplit_scheme"]  # only used by the commented-out loader variant
        test_split       = data_params["test_split"]
        xfold            = data_params["xfold"]             # only used by the commented-out loader variant
        fold_idx         = data_params["fold_idx"]
        random_seed      = data_params["random_seed"]
        train_loader = Data_fetcher.fetch_dataset(FLAGS.dataset, FLAGS.data_path, batch_size, batch_workers, shuffle, drop_last, 0.5, test_split, random_seed)
        #train_loader, test_loader = Data_fetcher.fetch_dataset(FLAGS.dataset, FLAGS.data_path, batch_size, batch_workers, shuffle, drop_last, 0.5, test_split, random_seed)
        #train_loader, test_loader = Data_fetcher.fetch_dataset_wValidation(FLAGS.dataset, FLAGS.data_path, batch_size, batch_workers, shuffle, drop_last, 0.5, datasplit_scheme, test_split, xfold, fold_idx, random_seed)

        # `with` guarantees the log is closed even if training raises.
        with open(FLAGS.log_path, "a") as log:
            log.write('Branch: %s  Fold ID: %d\n\n' % (branch_name, fold_idx))
            log.flush()

            step = 0
            if FLAGS.continue_model:
                checkpoint = '%s/net_step_%d.pth' % (FLAGS.model_save, FLAGS.model_step)
                # map_location lets GPU-saved checkpoints load on CPU-only hosts.
                self.net.load_state_dict(torch.load(checkpoint, map_location=self.device))
                step = FLAGS.model_step + 1

            self.net.train()
            lrec = []
            # best_f1 / best_step / stablize_step are only used by the
            # commented-out evaluation block below.
            best_f1 = -1.0
            best_step = 0
            stablize_step = 1000
            for epoch in range(epochs):
                epoch_loss = 0
                for i, data in enumerate(train_loader):
                    labels = data['label']
                    # Single-sample batches are skipped (presumably to avoid
                    # BatchNorm errors at batch size 1 -- confirm). The check
                    # runs before the device transfer so skipped batches are
                    # not copied.
                    if labels.shape[0] == 1:
                        continue
                    vol = data['img'].to(self.device)
                    labels = labels.to(self.device)

                    loss_ = self.optimize_step(vol, labels)
                    lrec.append(loss_)
                    epoch_loss = epoch_loss + loss_
                    step = step + 1

                    if step % FLAGS.print_step == 0:
                        # lrec is non-empty here: a loss is appended before every step increment.
                        msg = ('[%d/%d][%d/%d] loss: %.4f Step: %d'
                          %(epoch, epochs, i, len(train_loader), np.mean(lrec), step))
                        lrec[:] = []
                        print(msg)
                        log.write(msg+"\n")
                        log.flush()

#                         label_pred_cum = np.array((),dtype=np.int32)
#                         label_test_cum = np.array((),dtype=np.int32)
#                         for j, data_test in enumerate(test_loader, 0):
#                             vol_test   = data_test['vol'].unsqueeze(1).to(self.device)
#                             label_test = data_test['label']
#                             if label_test.shape[0] == 1:
#                                 continue
#                             pred_test  = self.net(vol_test)
#                             predicted  = torch.max(pred_test.data, 1)[1]
#                             label_pred_cum = np.concatenate((label_pred_cum, predicted.detach().cpu()))
#                             label_test_cum = np.concatenate((label_test_cum, label_test))

#                         accuracy = Utility_MEDICAL.compute_accuracy(label_test_cum,label_pred_cum)
#                         balanced_accuracy = Utility_MEDICAL.binary_balanced_evaluation(label_test_cum,label_pred_cum)
#                         specificity, sensitivity = Utility_MEDICAL.compute_specificity_sensitivity(label_test_cum,label_pred_cum)
#                         f1_score = Utility_MEDICAL.compute_F1(label_test_cum, label_pred_cum)
#                         auc = roc_auc_score(label_test_cum, label_pred_cum)

#                         if step > stablize_step:
#                             if f1_score > best_f1:
#                                 best_f1 = f1_score
#                                 best_step = step

#                         msg0 = ('Test accurady: %.4f  Balanced_accuracy: %.4f'%(accuracy, balanced_accuracy))
#                         msg1 = ('Specificity: %.4f  Sensitivity: %.4f'%(specificity, sensitivity))
#                         msg2 = ('AUC: %.4f\nF1 score: %.4f'%(auc, f1_score))
#                         msg3 = ('Best F1 score: %.4f  Step: %d' %(best_f1, best_step))
#                         print(msg0)
#                         print(msg1)
#                         print(msg2)
#                         print(msg3)
#                         log.write(msg0+"\n")
#                         log.write(msg1+"\n")
#                         log.write(msg2+"\n")
#                         log.write(msg3+"\n")
#                         log.flush()

                    if step % FLAGS.save_step == 0:
                        # ===== Save models ====
                        torch.save(self.net.state_dict(), '%s/net_step_%d.pth' % (FLAGS.model_save, step))
                self.scheduler.step(epoch_loss)
        print("Training complete.")

    def test(self, data_params, *,
             vgg_weights_path='E:/Data2/Pytorch_Log/vgg16/cnn.pkl',
             image_path="E:/Data2/Dog_Cat/train_subset2/cat.133.jpg"):
        """Run the current Grad-CAM / guided-backprop visualisation experiment.

        The active code does not evaluate ``self.net``. It builds a
        separately trained ``VGG16(n_classes=2)`` from ``vgg_weights_path``
        and visualises its ``layer5`` on the image at ``image_path``. The
        accuracy evaluation of ``self.net`` is kept below as commented-out
        code.

        Args:
            data_params: mapping of loader settings (same keys as ``train``
                minus the cross-validation ones).
            vgg_weights_path: state-dict file for the VGG16 model.
            image_path: image to visualise (read with OpenCV, BGR -> RGB).

        Raises:
            FileNotFoundError: if OpenCV cannot read ``image_path``.
        """
        epochs        = data_params["epochs"]
        batch_size    = data_params["batch_size"]
        batch_workers = data_params["batch_workers"]
        shuffle       = data_params["shuffle"]
        drop_last     = data_params["drop_last"]
        test_split    = data_params["test_split"]
        random_seed   = data_params["random_seed"]
        # Only consumed by the commented-out accuracy evaluation at the end.
        train_loader = Data_fetcher.fetch_dataset(FLAGS.dataset, FLAGS.data_path,  batch_size, batch_workers, shuffle, drop_last, 0.5, test_split, random_seed)
        #train_loader, test_loader = Data_fetcher.fetch_dataset(FLAGS.dataset, FLAGS.data_path,  batch_size, batch_workers, shuffle, drop_last, 0.5, test_split, random_seed)

        if FLAGS.continue_model:
            checkpoint = '%s/net_step_%d.pth' % (FLAGS.model_save, FLAGS.model_step)
            self.net.load_state_dict(torch.load(checkpoint, map_location=self.device))

#         from torchvision import models
#         model = models.resnet50(pretrained=True)
#         model = models.vgg16(pretrained=True)
#         model = model.features[:-2]
#         model = models.densenet161(pretrained=False)
#         target_layer = model.features.denseblock4.denselayer24
#         target_layer = model.layer4
#         target_layer = model[-1]
#         model = self.net.main
#         #print(model)
#         #target_layer = self.net.layer4[-1].conv3
#         target_layer = self.net.main[-11]
#         rgb_img = cv2.imread("E:/Data2/Dog_Cat/train_subset2/cat.33.jpg", 1)[:,:,::-1]
#         rgb_img = cv2.resize(rgb_img, (256, 256))
#         gradcam_guidedbackprop_visualize_single_image(rgb_img, model, target_layer, self.device)

        model = VGG16(n_classes=2)
        # The model is constructed on CPU; map the saved tensors there too so
        # checkpoints saved on a GPU still load on CPU-only machines.
        model.load_state_dict(torch.load(vgg_weights_path, map_location="cpu"))
        target_layer = model.layer5

        #rgb_img = np.float32(cv2.imread("E:/Data2/Dog_Cat/train_subset2/dog.248.jpg", 1))/255
        bgr_img = cv2.imread(image_path, 1)
        if bgr_img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError("Could not read image: %s" % image_path)
        rgb_img = bgr_img[:, :, ::-1]
        rgb_img = cv2.resize(rgb_img, (224, 224))
        gradcam_guidedbackprop_visualize_single_image(rgb_img, model, target_layer, self.device)

#         correct = 0
#         total   = 0
#         self.net.eval()
#         for i, data in enumerate(train_loader, 0):
#             vol = data['img'].to(self.device)
#             labels = data['label'].numpy()
#             out = self.net(vol)
#             out = np.argmax(out.detach().cpu().numpy(), axis=1)
#             total = total + out.shape[0]
#             correct = correct + np.sum(out == labels)
# #             target_layer = self.net.main[-7]
# #             gradcam_guidedbackprop_visualize(vol[0,:], self.net.main, target_layer, self.device)
# #             break
#         print(correct * 1.0 / total)