In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3,4"
import numpy as np
import sys
import scipy.ndimage as nd
import json
import pickle
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset, DataLoader
from resnet import *
import torch.optim as optim
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import time
import math
from utils import AverageMeter
import cv2
from PIL import Image
import torchvision.transforms as transforms
import torch.nn.functional as F
from dr_model import DRModel

In [2]:
def initial_cls_weights(cls):
    for m in cls.modules():
        if isinstance(m, nn.Conv2d):
            n = m.kernel_size[0]*m.kernel_size[1]*m.out_channels
            m.weight.data.normal_(0, math.sqrt(2./n))
            if m.bias is not None:
                m.bias.data.zero_()
        if isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        if isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)
            m.bias.data.zero_()
        if isinstance(m, nn.Conv3d):
            n = m.kernel_size[0]*m.kernel_size[1]*m.kernel_size[2]*m.out_channels
            m.weight.data.normal_(0, math.sqrt(2./n))
            if m.bias is not None:
                m.bias.data.zero_()

In [3]:
class BinClsDataSet(torch.utils.data.Dataset):
    def __init__(self, config_file):
        self.images_list = []
        self.labels_list = []
        with open(config_file, 'r') as f:
            for line in f.readlines():
                line = line.strip()
                if line is None or len(line) == 0:
                    continue
                ss = line.split('\t')
                if len(ss) != 2:
                    continue
                if not os.path.isfile(ss[0]):
                    continue
                self.images_list.append(ss[0])
                self.labels_list.append(ss[1])
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
#             transforms.Normalize(mean=[0], std=[255])
        ])
    def __getitem__(self, item):
        image = self.transform(Image.open(self.images_list[item]))
        return image, int(self.labels_list[item]), self.images_list[item]
    def __len__(self):
        return len(self.images_list)

In [4]:
# ds = BinClsDataSet(config_file='/data/zhangwd/data/examples/dr/train_label.txt')
# dataloader = DataLoader(ds, batch_size=2, 
#                                      shuffle=True, num_workers=2, 
#                                      pin_memory=True)
# for i, (images, labels, _) in enumerate(dataloader):
#     print(images.shape)
#     print(labels)
#     break

In [None]:
# class DRModel(nn.Module):
#     def __init__(self, name, inmap, multi_classes, weights=None, scratch=False):
#         super(DRModel, self).__init__()
#         self.name = name
#         self.weights = weights
#         self.inmap = inmap
#         self.multi_classes = multi_classes
#         self.featmap = inmap // 32
#         self.planes = 2048
#         base_model = None
#         if name == 'rsn18':
#             base_model = resnet18()
#             self.planes = 512
#         elif name == 'rsn34':
#             base_model = resnet34()
#             self.planes = 512
#         elif name == 'rsn50':
#             base_model = resnet50()
#             self.planes = 2048
#         elif name == 'rsn101':
#             base_model = resnet101()
#             self.planes = 2048
#         elif name == 'rsn152':
#             base_model = resnet152()
#             self.planes = 2048

# #         if not scratch:
# #             base_model.load_state_dict(torch.load('../pretrained/'+name+'.pth'))
            
#         self.base = nn.Sequential(*list(base_model.children())[:-2])
#         if name == 'rsn18' or name == 'rsn34' or name == 'rsn50' or name == 'rsn101' or name == 'rsn152':
#             self.base = nn.Sequential(*list(base_model.children())[:-2])
#         elif name == 'dsn121' or name == 'dsn161' or name == 'dsn169' or name == 'dsn201':
#             self.base = list(base_model.children())[0]

#         self.cls = nn.Linear(self.planes, multi_classes)

#         initial_cls_weights(self.cls)

#         if weights:
#             self.load_state_dict(torch.load(weights))

#     def forward(self, x):
#         feature = self.base(x)
#         # when 'inplace=True', some errors occur!!!!!!!!!!!!!!!!!!!!!!
#         out = F.relu(feature, inplace=False)
#         out = F.avg_pool2d(out, kernel_size=self.featmap).view(feature.size(0), -1)
#         out = self.cls(out)
#         return out

In [None]:
# model = DRModel('rsn34', 1024, 2)
# indata = torch.randn(2,1,1024,1024)
# model(indata)

In [None]:
def train(train_dataloader, model, criterion, optimizer, epoch, display):
    model.train()
    tot_pred = np.array([], dtype=int)
    tot_label = np.array([], dtype=int)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracy = AverageMeter()
    end = time.time()
    logger = []
    for num_iter, (images, labels, _) in enumerate(train_dataloader):
        data_time.update(time.time()-end)
        output = model(Variable(images.cuda()))
        loss = criterion(output, Variable(labels.cuda()))
        _, pred = torch.max(output, 1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time()-end)
        end = time.time()
        pred = pred.cpu().data.numpy()
        labels = labels.numpy()
        tot_pred = np.append(tot_pred, pred)
        tot_label = np.append(tot_label, labels)
        losses.update(loss.data.cpu().numpy(), len(images))
        accuracy.update(np.equal(pred, labels).sum()/len(labels), len(labels))
        if (num_iter+1) % display == 0:
            correct = np.equal(tot_pred, tot_label).sum()/len(tot_pred)
            print_info = 'Epoch: [{0}][{1}/{2}]\tTime {batch_time.val:3f} ({batch_time.avg:.3f})\t'\
                'Data {data_time.avg:.3f}\t''Loss {loss.avg:.4f}\tAccuray {accuracy.avg:.4f}'.format(
                epoch, num_iter, len(train_dataloader),batch_time=batch_time, data_time=data_time,
                loss=losses, accuracy=accuracy
            )
            print(print_info)
            logger.append(print_info)
    return accuracy.avg, logger

In [None]:
def val(train_dataloader, model, criterion, optimizer, epoch, display):
    model.eval()
    tot_pred = np.array([], dtype=int)
    tot_label = np.array([], dtype=int)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracy = AverageMeter()
    end = time.time()
    logger = []
    for num_iter, (images, labels,_) in enumerate(train_dataloader):
        data_time.update(time.time()-end)
        output = model(Variable(images.cuda()))
        loss = criterion(output, Variable(labels.cuda()))
        _, pred = torch.max(output, 1)
        batch_time.update(time.time()-end)
        end = time.time()
        pred = pred.cpu().data.numpy()
        labels = labels.numpy()
        tot_pred = np.append(tot_pred, pred)
        tot_label = np.append(tot_label, labels)
        losses.update(loss.data.cpu().numpy(), len(images))
        accuracy.update(np.equal(pred, labels).sum()/len(labels), len(labels))
        if (num_iter+1) % display == 0:
            correct = np.equal(tot_pred, tot_label).sum()/len(tot_pred)
            print_info = 'Epoch: [{0}][{1}/{2}]\tTime {batch_time.val:3f} ({batch_time.avg:.3f})\t'\
                'Data {data_time.avg:.3f}\t''Loss {loss.avg:.4f}\tAccuray {accuracy.avg:.4f}'.format(
                epoch, num_iter, len(train_dataloader),batch_time=batch_time, data_time=data_time,
                loss=losses, accuracy=accuracy
            )
            print(print_info)
            logger.append(print_info)
    return accuracy.avg, logger

In [None]:
def test(train_dataloader, model, criterion, optimizer, epoch, display):
    val(train_dataloader, model, criterion, optimizer, epoch, display)

In [None]:
def main():
    config_file = './config_dr_2d.json'
    config = None
    with open(config_file) as f:
        config = json.load(f)
    print('\n')
    print('====> parse options:')
    print(config)
    print('\n')
    
    train_file = os.path.join(os.path.join(config["out_data_path"], 'train', 'flags.txt')) if config["train_list_file"]=="" else config["train_list_file"]
    val_file = os.path.join(os.path.join(config["out_data_path"], 'val', 'flags.txt')) if config["val_list_file"]=="" else config["val_list_file"]
    print('training flags file path:\t{}'.format(train_file))
    print('validation flags file path:\t{}'.format(val_file))
    
    print('====> create output model path:\t')
    os.makedirs(config["model_dir"], exist_ok=True)
    time_stamp = time.strftime('%Y%m%d%H%M%S', time.localtime(time.time()))
    time_stamp = config["export_name"]
    model_dir = os.path.join(config["model_dir"], 'dr_cls_{}'.format(time_stamp))
    os.makedirs(model_dir, exist_ok=True)
    
    print('====> building model:\t')
    
    model = DRModel('rsn34', config['scale'], config["num_classes"])
    initial_cls_weights(model)
    pretrained_weights = config['weight']
    if pretrained_weights is not None:
        model.load_state_dict(torch.load(pretrained_weights))
    
    criterion = nn.CrossEntropyLoss().cuda()
    
    batch_size = config['batch_size']
    num_workers = config['num_workers']
    if config['phase'] == 'train':
        train_ds = BinClsDataSet(train_file)
        val_ds = BinClsDataSet(val_file)    
        train_dataloader = DataLoader(train_ds, batch_size=batch_size, 
                                     shuffle=True, num_workers=num_workers, 
                                     pin_memory=True)
        val_dataloader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, 
                                   num_workers=num_workers, pin_memory=False)
        best_acc = 0.6
        epochs = config['epoch']
        display = config['display']
        for epoch in range(epochs):
            if epoch < config['fix']:
                lr = config['lr']
            else:
                lr = config['lr'] * (0.1 ** (epoch//config['step']))
            mom = config['mom']
            wd = config['wd']
            optimizer = None
            if config['optimizer'] == 'sgd':
                optimizer = optim.SGD([{'params': model.parameters()}], 
                                      lr=lr, momentum=mom, weight_decay=wd, nesterov=True)
            print('====> train:\t')
            _, _ = train(train_dataloader, nn.DataParallel(model).cuda(), criterion, optimizer, epoch, display)
            print('====> validate:\t')
            acc, logger = val(val_dataloader, nn.DataParallel(model).cuda(), criterion, optimizer, epoch, display)
            print('val acc:\t{:.3f}'.format(acc))
            if acc > best_acc:
                print('\ncurrent best accuracy is: {}\n'.format(acc))
                best_acc = acc
                saved_model_name = os.path.join(model_dir, 'ct_pos_recognition_{:04d}_best.pth'.format(epoch))
                torch.save(model.cpu().state_dict(), saved_model_name)
                print('====> save model:\t{}'.format(saved_model_name))
    elif config['phase'] == 'test':
        print('====> begin to test:')
        test_file = os.path.join(os.path.join(config["out_data_path"], 'test', 'flags.txt')) if config["test_list_file"]=="" else config["test_list_file"]
        print('test flags file path:\t{}'.format(test_file))
        test_ds = BinClsDataSet(test_file)
        test_dataloader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, 
                                   num_workers=num_workers, pin_memory=False)
        acc, logger = val(test_dataloader, nn.DataParallel(model).cuda(), criterion, 0, 10)
        print('\t====> test accuracy is {:.3f}'.format(acc))
        print('====> end to test!')

In [None]:
main()



====> parse options:
{'train_list_file': '/data/zhangwd/data/examples/dr_deformable/train_label.txt', 'val_list_file': '/data/zhangwd/data/examples/dr_deformable/val_label.txt', 'test_list_file': '/data/zhangwd/data/examples/dr_deformable/test_label.txt', 'model_dir': './model', 'num_classes': 2, 'scale': 1024, 'phase': 'train', 'model': 'resnet34', 'weight': None, 'lr': 0.001, 'mom': 0.9, 'wd': 0.0001, 'fix': 50, 'step': 20, 'epoch': 120, 'display': 100, 'num_workers': 8, 'batch_size': 8, 'dim_z': 128, 'dim_x': 128, 'optimizer': 'sgd'}


training flags file path:	/data/zhangwd/data/examples/dr_deformable/train_label.txt
validation flags file path:	/data/zhangwd/data/examples/dr_deformable/val_label.txt
====> create output model path:	
====> building model:	
====> train:	
Epoch: [0][99/882]	Time 0.276423 (0.366)	Data 0.174	Loss 0.5560	Accuray 0.7650
Epoch: [0][199/882]	Time 0.275603 (0.324)	Data 0.172	Loss 0.5591	Accuray 0.7681
Epoch: [0][299/882]	Time 0.285715 (0.310)	Data 0.172	Los

Epoch: [8][399/882]	Time 0.282635 (0.287)	Data 0.176	Loss 0.3084	Accuray 0.8703
Epoch: [8][499/882]	Time 0.299967 (0.287)	Data 0.175	Loss 0.3096	Accuray 0.8672
Epoch: [8][599/882]	Time 0.293870 (0.287)	Data 0.176	Loss 0.3102	Accuray 0.8671
Epoch: [8][699/882]	Time 0.288865 (0.287)	Data 0.176	Loss 0.3166	Accuray 0.8664
Epoch: [8][799/882]	Time 0.297126 (0.287)	Data 0.176	Loss 0.3186	Accuray 0.8653
====> validate:	
Epoch: [8][99/247]	Time 0.086456 (0.097)	Data 0.015	Loss 0.0559	Accuray 0.9925
Epoch: [8][199/247]	Time 0.087555 (0.092)	Data 0.011	Loss 0.0719	Accuray 0.9869
val acc:	0.876
====> train:	
Epoch: [9][99/882]	Time 0.285372 (0.292)	Data 0.182	Loss 0.2947	Accuray 0.8775
Epoch: [9][199/882]	Time 0.295214 (0.288)	Data 0.178	Loss 0.3158	Accuray 0.8650
Epoch: [9][299/882]	Time 0.272590 (0.287)	Data 0.177	Loss 0.3140	Accuray 0.8650
Epoch: [9][399/882]	Time 0.277684 (0.287)	Data 0.176	Loss 0.3042	Accuray 0.8722
Epoch: [9][499/882]	Time 0.278080 (0.286)	Data 0.176	Loss 0.3039	Accuray 0.8

Epoch: [17][799/882]	Time 0.286314 (0.288)	Data 0.175	Loss 0.2758	Accuray 0.8864
====> validate:	
Epoch: [17][99/247]	Time 0.090070 (0.097)	Data 0.015	Loss 0.0418	Accuray 0.9962
Epoch: [17][199/247]	Time 0.092690 (0.093)	Data 0.011	Loss 0.0624	Accuray 0.9850
val acc:	0.870
====> train:	
Epoch: [18][99/882]	Time 0.296273 (0.294)	Data 0.182	Loss 0.2700	Accuray 0.8850
Epoch: [18][199/882]	Time 0.296105 (0.290)	Data 0.177	Loss 0.2699	Accuray 0.8850
Epoch: [18][299/882]	Time 0.256583 (0.289)	Data 0.175	Loss 0.2627	Accuray 0.8921
Epoch: [18][399/882]	Time 0.305475 (0.288)	Data 0.174	Loss 0.2781	Accuray 0.8831
Epoch: [18][499/882]	Time 0.290449 (0.288)	Data 0.175	Loss 0.2796	Accuray 0.8832
Epoch: [18][599/882]	Time 0.283760 (0.288)	Data 0.174	Loss 0.2766	Accuray 0.8854
Epoch: [18][699/882]	Time 0.295146 (0.287)	Data 0.174	Loss 0.2749	Accuray 0.8857
Epoch: [18][799/882]	Time 0.293602 (0.287)	Data 0.174	Loss 0.2749	Accuray 0.8862
====> validate:	
Epoch: [18][99/247]	Time 0.090843 (0.097)	Data 0

Epoch: [27][299/882]	Time 0.288041 (0.292)	Data 0.182	Loss 0.2245	Accuray 0.9108
Epoch: [27][399/882]	Time 0.292298 (0.291)	Data 0.182	Loss 0.2310	Accuray 0.9059
Epoch: [27][499/882]	Time 0.279183 (0.291)	Data 0.182	Loss 0.2337	Accuray 0.9045
Epoch: [27][599/882]	Time 0.285994 (0.291)	Data 0.181	Loss 0.2428	Accuray 0.9010
Epoch: [27][699/882]	Time 0.281702 (0.291)	Data 0.181	Loss 0.2436	Accuray 0.9011
Epoch: [27][799/882]	Time 0.288070 (0.291)	Data 0.181	Loss 0.2488	Accuray 0.8961
====> validate:	
Epoch: [27][99/247]	Time 0.089868 (0.097)	Data 0.015	Loss 0.1500	Accuray 0.9450
Epoch: [27][199/247]	Time 0.089880 (0.094)	Data 0.011	Loss 0.1657	Accuray 0.9394
val acc:	0.880
====> train:	
Epoch: [28][99/882]	Time 0.301297 (0.299)	Data 0.187	Loss 0.1932	Accuray 0.9200
Epoch: [28][199/882]	Time 0.296033 (0.295)	Data 0.184	Loss 0.2267	Accuray 0.9081
Epoch: [28][299/882]	Time 0.289647 (0.293)	Data 0.182	Loss 0.2321	Accuray 0.9042
Epoch: [28][399/882]	Time 0.295414 (0.292)	Data 0.181	Loss 0.2339

Epoch: [36][199/247]	Time 0.086831 (0.096)	Data 0.012	Loss 0.7982	Accuray 0.6794
val acc:	0.720
====> train:	
Epoch: [37][99/882]	Time 0.276952 (0.297)	Data 0.188	Loss 0.1721	Accuray 0.9400
Epoch: [37][199/882]	Time 0.281281 (0.293)	Data 0.183	Loss 0.2045	Accuray 0.9219
Epoch: [37][299/882]	Time 0.285455 (0.292)	Data 0.182	Loss 0.2032	Accuray 0.9233
Epoch: [37][399/882]	Time 0.281174 (0.291)	Data 0.181	Loss 0.2011	Accuray 0.9219
Epoch: [37][499/882]	Time 0.281397 (0.291)	Data 0.181	Loss 0.2008	Accuray 0.9197
Epoch: [37][599/882]	Time 0.297622 (0.291)	Data 0.181	Loss 0.1977	Accuray 0.9206
Epoch: [37][699/882]	Time 0.272541 (0.290)	Data 0.181	Loss 0.1986	Accuray 0.9207
Epoch: [37][799/882]	Time 0.288461 (0.290)	Data 0.181	Loss 0.1982	Accuray 0.9209
====> validate:	
Epoch: [37][99/247]	Time 0.093024 (0.101)	Data 0.016	Loss 0.5210	Accuray 0.7975
Epoch: [37][199/247]	Time 0.089463 (0.096)	Data 0.012	Loss 0.5173	Accuray 0.7913
val acc:	0.792
====> train:	
Epoch: [38][99/882]	Time 0.289659 (0

Epoch: [46][599/882]	Time 0.291237 (0.289)	Data 0.180	Loss 0.1094	Accuray 0.9585
Epoch: [46][699/882]	Time 0.288901 (0.289)	Data 0.180	Loss 0.1123	Accuray 0.9566
Epoch: [46][799/882]	Time 0.289612 (0.289)	Data 0.179	Loss 0.1146	Accuray 0.9561
====> validate:	
Epoch: [46][99/247]	Time 0.093734 (0.110)	Data 0.021	Loss 0.8574	Accuray 0.7025
Epoch: [46][199/247]	Time 0.089549 (0.101)	Data 0.014	Loss 0.8472	Accuray 0.7137
val acc:	0.742
====> train:	
Epoch: [47][99/882]	Time 0.283594 (0.296)	Data 0.188	Loss 0.1004	Accuray 0.9675
Epoch: [47][199/882]	Time 0.271305 (0.292)	Data 0.185	Loss 0.1094	Accuray 0.9663
Epoch: [47][299/882]	Time 0.297302 (0.291)	Data 0.183	Loss 0.1124	Accuray 0.9646
Epoch: [47][399/882]	Time 0.291667 (0.290)	Data 0.183	Loss 0.1074	Accuray 0.9653
Epoch: [47][499/882]	Time 0.287414 (0.290)	Data 0.183	Loss 0.1021	Accuray 0.9675
Epoch: [47][599/882]	Time 0.277996 (0.289)	Data 0.182	Loss 0.1043	Accuray 0.9665
Epoch: [47][699/882]	Time 0.272589 (0.289)	Data 0.179	Loss 0.1074

Epoch: [56][199/882]	Time 0.297921 (0.293)	Data 0.181	Loss 0.0513	Accuray 0.9838
Epoch: [56][299/882]	Time 0.287512 (0.292)	Data 0.179	Loss 0.0542	Accuray 0.9833
Epoch: [56][399/882]	Time 0.279033 (0.290)	Data 0.179	Loss 0.0494	Accuray 0.9847
Epoch: [56][499/882]	Time 0.301554 (0.290)	Data 0.179	Loss 0.0478	Accuray 0.9842
Epoch: [56][599/882]	Time 0.280170 (0.289)	Data 0.179	Loss 0.0474	Accuray 0.9844
Epoch: [56][699/882]	Time 0.289244 (0.289)	Data 0.178	Loss 0.0467	Accuray 0.9845
Epoch: [56][799/882]	Time 0.302483 (0.289)	Data 0.178	Loss 0.0459	Accuray 0.9852
====> validate:	
Epoch: [56][99/247]	Time 0.104277 (0.103)	Data 0.020	Loss 0.2476	Accuray 0.9062
Epoch: [56][199/247]	Time 0.093738 (0.097)	Data 0.014	Loss 0.2706	Accuray 0.9025
val acc:	0.862
====> train:	
Epoch: [57][99/882]	Time 0.285409 (0.297)	Data 0.186	Loss 0.0432	Accuray 0.9888
Epoch: [57][199/882]	Time 0.284065 (0.293)	Data 0.182	Loss 0.0543	Accuray 0.9838
Epoch: [57][299/882]	Time 0.276103 (0.292)	Data 0.180	Loss 0.0544