In [1]:
import pandas as pd
import numpy as np
import random
import easydict
import ast
from tqdm import tqdm_notebook
import os
import matplotlib.pyplot as plt
import sys

from torchvision import transforms
from PIL import Image
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torch
import torchvision.models as models
import torch.nn.functional as F

from functions.cus_plot import plot_train_process,plot_different_figs
from functions.helper import *

In [2]:
args=easydict.EasyDict({
    'batch_size':16,
    'epoch':50,
    'data_root_path':'d:/Data/CUB/CUB_200_2011/',
    'csv_path':'save/data/data.csv',
    'lr':0.001,
    'train_log':'save/train_log.txt',
    'train_model':'save/model/deconv_vgg16.pt',
    'train_img':'save/img/train_process.png',
    'class_nums':200,
    'train_img_size':224,
    'weight_decay':0.0001,
    'nesterov':False,
    'check_path':'save/train/check.pkl',
    'continue_train':False
})

# Dataloader

In [3]:
class Loader:
    def __init__(self, args, mode='train'):
        self.data_csv = pd.read_csv(args.csv_path)
        self.mode = mode
        self.args = args
        self.img_size = args.train_img_size
        self.class_nums=args.class_nums

        self.train_csv=self.data_csv[self.data_csv['is_train']==1]
        self.val_csv=self.data_csv[self.data_csv['is_train']==0]
        self.train_csv.reset_index(drop=True,inplace=True)
        self.val_csv.reset_index(drop=True,inplace=True)
        
        if self.mode=='train':
            self.cur_csv=self.train_csv
        else:
            self.cur_csv=self.val_csv

    def __getitem__(self, index):
        item = self.cur_csv.loc[index]

        img_id = item['id']
        path = item['path']
        label = item['cls']
        bbox = item['bbox']

        raw_img = Image.open(self.args.data_root_path + path).convert('RGB')
        img = self.image_transform(img_size=self.img_size, mode=self.mode)(raw_img)

        return img_id, img, label, bbox

    def to_train(self):
        self.mode = 'train'
        self.cur_csv = self.train_csv

    def to_val(self):
        self.mode = 'val'
        self.cur_csv = self.val_csv

    def __len__(self):
        return len(self.cur_csv)
    
    @staticmethod
    def image_transform(img_size, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], mode='train'):
        if mode == 'train':
            horizontal_flip = 0.5
            vertical_flip = 0.5

            t = [
                transforms.RandomResizedCrop(size=img_size),
                transforms.RandomHorizontalFlip(horizontal_flip),
                transforms.RandomVerticalFlip(vertical_flip),
                transforms.ColorJitter(saturation=0.4, brightness=0.4, hue=0.05),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ]

        else:
            t = [
                transforms.Resize((img_size, img_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ]

        return transforms.Compose([v for v in t])

# 建模

In [4]:
class Deconv_vgg(nn.Module):
    def __init__(self, pre_trained_vgg, args, inference=False, freeze_vgg=False):
        super(Deconv_vgg, self).__init__()
        self.inference = inference
        self.freeze_vgg = freeze_vgg
        self.class_nums=args.class_nums

        self.features = pre_trained_vgg.features
        self.cls = nn.Sequential(
            nn.Dropout(0.5),
            nn.Conv2d(512, 1024, kernel_size=3, padding=1, dilation=1),  # fc6
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Conv2d(1024, 1024, kernel_size=3, padding=1, dilation=1),  # fc6
            nn.ReLU(True),
            nn.Conv2d(1024, self.class_nums, kernel_size=1, padding=0)  # fc8
        )
        
        self.deconv=nn.Sequential(
            nn.ConvTranspose2d(in_channels=200,out_channels=1024,kernel_size=3,stride=2,padding=1,output_padding=1),
            nn.ReLU(True),
            nn.ConvTranspose2d(in_channels=1024,out_channels=1024,kernel_size=3,stride=2,padding=1,output_padding=1),
            nn.ReLU(True),
            nn.ConvTranspose2d(in_channels=1024,out_channels=512,kernel_size=3,stride=2,padding=1,output_padding=1),
            nn.ReLU(True),
            nn.ConvTranspose2d(in_channels=512,out_channels=512,kernel_size=3,stride=2,padding=1,output_padding=1),
            nn.ReLU(True),
            nn.ConvTranspose2d(in_channels=512,out_channels=200,kernel_size=3,stride=2,padding=1,output_padding=1),
            nn.ReLU(True)
        )

        if self.freeze_vgg:
            for param in self.features.parameters():
                param.requires_grad = False

    def forward(self, x):
        if self.inference:
            x.requires_grad_()
            x.retain_grad()

        base = self.features(x)
        avg_pool = F.avg_pool2d(base, kernel_size=3, stride=1, padding=1)
        cam = self.cls(avg_pool)
        
        deconv_cam=self.deconv(cam)
        logits=torch.mean(torch.mean(deconv_cam,dim=2),dim=2)

        if self.inference:
            pass

        return logits,deconv_cam
    
    def norm_cam_2_binary(self, bi_x_grad):
        thd = float(np.percentile(np.sort(bi_x_grad.view(-1).cpu().data.numpy()), 80))
        outline = torch.zeros(bi_x_grad.size())
        high_pos = torch.gt(bi_x_grad, thd)
        outline[high_pos.data] = 1.0
        
        return outline

In [5]:
def get_model(pretrained=True,**kwargs):
    pre_trained_model = models.vgg16(pretrained=pretrained)

    model = Deconv_vgg(pre_trained_vgg=pre_trained_model, **kwargs)
    model.cuda()

    return model

# Loss

In [6]:
def get_loss_func(args):
    return torch.nn.CrossEntropyLoss()

# Optimzer

In [7]:
def get_finetune_optimizer(args, model):
    lr = args.lr
    weight_list = []
    bias_list = []
    last_weight_list = []
    last_bias_list = []
    
    for name, value in model.named_parameters():
        if 'cls' in name:
            if 'weight' in name:
                last_weight_list.append(value)
            elif 'bias' in name:
                last_bias_list.append(value)
        else:
            if 'weight' in name:
                weight_list.append(value)
            elif 'bias' in name:
                bias_list.append(value)

    opt = optim.SGD([{'params': weight_list, 'lr': lr / 10},
                     {'params': bias_list, 'lr': lr / 5},
                     {'params': last_weight_list, 'lr': lr},
                     {'params': last_bias_list, 'lr': lr * 2}], momentum=0.9, weight_decay=args.weight_decay, nesterov=args.nesterov)

    return opt

# 训练

In [8]:
#初始化
torch.cuda.empty_cache()
epoch=0
train_acc_arr=[]
val_acc_arr=[]

#加载数据
dataset=Loader(args=args)
dataloader=DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

#加载模型
model=get_model(args=args)

#加载参数
loss_func=get_loss_func(args=args)
    
#训练
if args.continue_train:
    model.load_state_dict(torch.load(args.train_model))
    check_dict=load_check_point(args=args)
    epoch=check_dict['epoch']
    train_acc_arr=check_dict['train_acc_arr']
    val_acc_arr=check_dict['val_acc_arr']
    
while epoch <args.epoch:
    opt = get_finetune_optimizer(args, model)
    
    train_result = []
    train_label = []
    val_result = []
    val_label = []
    
    for step, (img_id, img, label, bbox) in enumerate(dataloader):
        img = img.cuda()
        label = label.cuda()

        logits, cam = model.forward(img)
        loss = loss_func(logits, label)
        acc = cal_acc(logits, label)

        opt.zero_grad()
        loss.backward()
        opt.step()

        print('epoch:{} train loss:{} train acc:{}'.format(epoch, loss, acc))

        train_result.extend(torch.argmax(logits, dim=-1).cpu().data.numpy())
        train_label.extend(label.cpu().data.numpy())
    
    train_acc_arr.append(np.mean(np.array(train_result) == np.array(train_label)))
    
    # validation
    dataset.to_val()
    val_dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)
    for step, (img_id, img, label, bbox) in enumerate(tqdm_notebook(val_dataloader)):
        img = img.cuda()
        label = label.cuda()

        logits, cam = model.forward(img)
        val_result.extend(torch.argmax(logits, dim=-1).cpu().data.numpy())
        val_label.extend(label.cpu().data.numpy())
        
        if step==0:
            target_cls=torch.argmax(logits,dim=-1)
            
            plot_dict={}
            plot_dict['raw_imgs']=get_raw_imgs_by_id(args,img_id[:5],dataset)
            target_cams=[]
            for i in range(5):
                raw_img_size=plot_dict['raw_imgs'][i].size
                target_cam=cam[i][target_cls[i]].unsqueeze(0).unsqueeze(0).detach().cpu().data
                up_target_cam=F.upsample(target_cam, size=(raw_img_size[1],raw_img_size[0]), mode='bilinear', align_corners=True)
                target_cams.append(up_target_cam.squeeze())
                
            plot_dict['cams']=target_cams
            plot_different_figs(plot_dict)
            

    val_acc_arr.append(np.mean(np.array(val_result) == np.array(val_label)))
    
    if len(val_acc_arr)==1 or val_acc_arr[-1]>=val_acc_arr[-2]:
        torch.save(model.state_dict(), args.train_model)
    
    #plot
    plot_train_process(args,[train_acc_arr,val_acc_arr])
    
    #save check point
    epoch+=1
    save_check_point(args=args,check_dict={
        'epoch':epoch,
        'train_acc_arr':train_acc_arr,
        'val_acc_arr':val_acc_arr
    })
    
    dataset.to_train()

epoch:0 train loss:5.2978973388671875 train acc:0.0
epoch:0 train loss:5.300558567047119 train acc:0.0
epoch:0 train loss:5.302445411682129 train acc:0.0
epoch:0 train loss:5.298092842102051 train acc:0.0
epoch:0 train loss:5.2988057136535645 train acc:0.0
epoch:0 train loss:5.296817302703857 train acc:0.0
epoch:0 train loss:5.2963433265686035 train acc:0.0
epoch:0 train loss:5.3005266189575195 train acc:0.0
epoch:0 train loss:5.300187587738037 train acc:0.0
epoch:0 train loss:5.300678730010986 train acc:0.0
epoch:0 train loss:5.29926872253418 train acc:0.0
epoch:0 train loss:5.2997026443481445 train acc:0.0
epoch:0 train loss:5.297369003295898 train acc:0.0
epoch:0 train loss:5.295777797698975 train acc:0.0
epoch:0 train loss:5.300346851348877 train acc:0.0
epoch:0 train loss:5.302643775939941 train acc:0.0
epoch:0 train loss:5.300783157348633 train acc:0.0
epoch:0 train loss:5.295696258544922 train acc:0.0
epoch:0 train loss:5.300888538360596 train acc:0.0
epoch:0 train loss:5.294212

epoch:0 train loss:5.2984113693237305 train acc:0.0
epoch:0 train loss:5.3000168800354 train acc:0.0
epoch:0 train loss:5.298556804656982 train acc:0.0
epoch:0 train loss:5.2960944175720215 train acc:0.0
epoch:0 train loss:5.2956414222717285 train acc:0.0625
epoch:0 train loss:5.300796985626221 train acc:0.0
epoch:0 train loss:5.295852184295654 train acc:0.0625
epoch:0 train loss:5.296573162078857 train acc:0.0
epoch:0 train loss:5.297729015350342 train acc:0.0
epoch:0 train loss:5.29569149017334 train acc:0.0
epoch:0 train loss:5.29852294921875 train acc:0.0
epoch:0 train loss:5.302338600158691 train acc:0.0
epoch:0 train loss:5.300187110900879 train acc:0.0
epoch:0 train loss:5.300409317016602 train acc:0.0
epoch:0 train loss:5.2977166175842285 train acc:0.0
epoch:0 train loss:5.298351764678955 train acc:0.0
epoch:0 train loss:5.298081398010254 train acc:0.0
epoch:0 train loss:5.2973175048828125 train acc:0.0
epoch:0 train loss:5.2958292961120605 train acc:0.0
epoch:0 train loss:5.29

HBox(children=(IntProgress(value=0, max=363), HTML(value='')))




epoch:1 train loss:5.298020362854004 train acc:0.0
epoch:1 train loss:5.298901557922363 train acc:0.0
epoch:1 train loss:5.299846649169922 train acc:0.0
epoch:1 train loss:5.298932075500488 train acc:0.0
epoch:1 train loss:5.299930095672607 train acc:0.0
epoch:1 train loss:5.298301696777344 train acc:0.0625
epoch:1 train loss:5.299336910247803 train acc:0.0
epoch:1 train loss:5.298492431640625 train acc:0.0
epoch:1 train loss:5.297505855560303 train acc:0.0
epoch:1 train loss:5.30026912689209 train acc:0.0
epoch:1 train loss:5.2978010177612305 train acc:0.0
epoch:1 train loss:5.300914287567139 train acc:0.0
epoch:1 train loss:5.299271583557129 train acc:0.0
epoch:1 train loss:5.2985520362854 train acc:0.0
epoch:1 train loss:5.300684452056885 train acc:0.0
epoch:1 train loss:5.296081066131592 train acc:0.0625
epoch:1 train loss:5.29858922958374 train acc:0.0
epoch:1 train loss:5.297869682312012 train acc:0.0
epoch:1 train loss:5.296184539794922 train acc:0.0
epoch:1 train loss:5.297938

epoch:1 train loss:5.299102306365967 train acc:0.0
epoch:1 train loss:5.2978949546813965 train acc:0.0
epoch:1 train loss:5.302248001098633 train acc:0.0
epoch:1 train loss:5.2970733642578125 train acc:0.0
epoch:1 train loss:5.296627044677734 train acc:0.0
epoch:1 train loss:5.299118518829346 train acc:0.0
epoch:1 train loss:5.300253391265869 train acc:0.0
epoch:1 train loss:5.297296524047852 train acc:0.0
epoch:1 train loss:5.298335075378418 train acc:0.0
epoch:1 train loss:5.294458866119385 train acc:0.0
epoch:1 train loss:5.295547962188721 train acc:0.0
epoch:1 train loss:5.298795700073242 train acc:0.0
epoch:1 train loss:5.300189971923828 train acc:0.0
epoch:1 train loss:5.299688339233398 train acc:0.0
epoch:1 train loss:5.293947219848633 train acc:0.0
epoch:1 train loss:5.298079490661621 train acc:0.0
epoch:1 train loss:5.299742221832275 train acc:0.0
epoch:1 train loss:5.298617362976074 train acc:0.0625
epoch:1 train loss:5.297658443450928 train acc:0.0
epoch:1 train loss:5.29948

HBox(children=(IntProgress(value=0, max=363), HTML(value='')))


epoch:2 train loss:5.298397541046143 train acc:0.0
epoch:2 train loss:5.2989630699157715 train acc:0.0
epoch:2 train loss:5.299751281738281 train acc:0.0
epoch:2 train loss:5.301172256469727 train acc:0.0
epoch:2 train loss:5.299623489379883 train acc:0.0
epoch:2 train loss:5.299569129943848 train acc:0.0
epoch:2 train loss:5.298774719238281 train acc:0.0
epoch:2 train loss:5.295257091522217 train acc:0.0
epoch:2 train loss:5.298083305358887 train acc:0.0
epoch:2 train loss:5.298882484436035 train acc:0.0
epoch:2 train loss:5.294376850128174 train acc:0.0
epoch:2 train loss:5.298440933227539 train acc:0.0
epoch:2 train loss:5.297898292541504 train acc:0.0
epoch:2 train loss:5.29632568359375 train acc:0.0
epoch:2 train loss:5.29530143737793 train acc:0.0
epoch:2 train loss:5.299814224243164 train acc:0.0
epoch:2 train loss:5.296966075897217 train acc:0.0
epoch:2 train loss:5.2960991859436035 train acc:0.0
epoch:2 train loss:5.300508975982666 train acc:0.0
epoch:2 train loss:5.302642822

epoch:2 train loss:5.297835826873779 train acc:0.0
epoch:2 train loss:5.298617839813232 train acc:0.0
epoch:2 train loss:5.296379089355469 train acc:0.125
epoch:2 train loss:5.298405647277832 train acc:0.0
epoch:2 train loss:5.298053741455078 train acc:0.0
epoch:2 train loss:5.29617166519165 train acc:0.0625
epoch:2 train loss:5.300352096557617 train acc:0.0
epoch:2 train loss:5.296236515045166 train acc:0.0625
epoch:2 train loss:5.299568176269531 train acc:0.0
epoch:2 train loss:5.298504829406738 train acc:0.0
epoch:2 train loss:5.297752380371094 train acc:0.0
epoch:2 train loss:5.2988433837890625 train acc:0.0
epoch:2 train loss:5.300076007843018 train acc:0.0
epoch:2 train loss:5.29829740524292 train acc:0.0
epoch:2 train loss:5.2971649169921875 train acc:0.0
epoch:2 train loss:5.297399520874023 train acc:0.0
epoch:2 train loss:5.299853801727295 train acc:0.0
epoch:2 train loss:5.301611423492432 train acc:0.0
epoch:2 train loss:5.295914649963379 train acc:0.0
epoch:2 train loss:5.30

HBox(children=(IntProgress(value=0, max=363), HTML(value='')))


epoch:3 train loss:5.2953386306762695 train acc:0.0625
epoch:3 train loss:5.295817852020264 train acc:0.0
epoch:3 train loss:5.2983503341674805 train acc:0.0
epoch:3 train loss:5.301715850830078 train acc:0.0
epoch:3 train loss:5.30076265335083 train acc:0.0
epoch:3 train loss:5.298306465148926 train acc:0.0
epoch:3 train loss:5.300664901733398 train acc:0.0
epoch:3 train loss:5.300920486450195 train acc:0.0
epoch:3 train loss:5.29803991317749 train acc:0.0
epoch:3 train loss:5.299860000610352 train acc:0.0
epoch:3 train loss:5.300347805023193 train acc:0.0
epoch:3 train loss:5.2979817390441895 train acc:0.0
epoch:3 train loss:5.300286769866943 train acc:0.0
epoch:3 train loss:5.298632621765137 train acc:0.0
epoch:3 train loss:5.299382209777832 train acc:0.0
epoch:3 train loss:5.299435615539551 train acc:0.0
epoch:3 train loss:5.295721054077148 train acc:0.0
epoch:3 train loss:5.301209926605225 train acc:0.0
epoch:3 train loss:5.301753520965576 train acc:0.0
epoch:3 train loss:5.29857

epoch:3 train loss:5.298903465270996 train acc:0.0
epoch:3 train loss:5.2996039390563965 train acc:0.0
epoch:3 train loss:5.295580863952637 train acc:0.0
epoch:3 train loss:5.295234203338623 train acc:0.0
epoch:3 train loss:5.298060894012451 train acc:0.0
epoch:3 train loss:5.299359321594238 train acc:0.0
epoch:3 train loss:5.298300266265869 train acc:0.0
epoch:3 train loss:5.2968573570251465 train acc:0.0
epoch:3 train loss:5.296949863433838 train acc:0.0625
epoch:3 train loss:5.295220375061035 train acc:0.0625
epoch:3 train loss:5.297420978546143 train acc:0.0
epoch:3 train loss:5.299088001251221 train acc:0.0
epoch:3 train loss:5.298323154449463 train acc:0.0
epoch:3 train loss:5.297237873077393 train acc:0.0
epoch:3 train loss:5.298551559448242 train acc:0.0
epoch:3 train loss:5.295932769775391 train acc:0.0
epoch:3 train loss:5.296392440795898 train acc:0.0
epoch:3 train loss:5.300436973571777 train acc:0.0
epoch:3 train loss:5.299074649810791 train acc:0.0
epoch:3 train loss:5.29

HBox(children=(IntProgress(value=0, max=363), HTML(value='')))


epoch:4 train loss:5.301733016967773 train acc:0.0
epoch:4 train loss:5.296097755432129 train acc:0.0
epoch:4 train loss:5.297606468200684 train acc:0.0625
epoch:4 train loss:5.300663948059082 train acc:0.0
epoch:4 train loss:5.298459529876709 train acc:0.0
epoch:4 train loss:5.298678398132324 train acc:0.0
epoch:4 train loss:5.300393581390381 train acc:0.0
epoch:4 train loss:5.301620006561279 train acc:0.0
epoch:4 train loss:5.298038959503174 train acc:0.0
epoch:4 train loss:5.295742511749268 train acc:0.0
epoch:4 train loss:5.298316478729248 train acc:0.0
epoch:4 train loss:5.299464225769043 train acc:0.0
epoch:4 train loss:5.2978434562683105 train acc:0.0
epoch:4 train loss:5.296684265136719 train acc:0.0
epoch:4 train loss:5.300664901733398 train acc:0.0
epoch:4 train loss:5.300563812255859 train acc:0.0
epoch:4 train loss:5.296679496765137 train acc:0.0
epoch:4 train loss:5.294901371002197 train acc:0.0
epoch:4 train loss:5.299020290374756 train acc:0.0
epoch:4 train loss:5.29918

epoch:4 train loss:5.2994818687438965 train acc:0.0
epoch:4 train loss:5.300606727600098 train acc:0.0
epoch:4 train loss:5.299722194671631 train acc:0.0
epoch:4 train loss:5.30001974105835 train acc:0.0
epoch:4 train loss:5.2990217208862305 train acc:0.0
epoch:4 train loss:5.299363613128662 train acc:0.0
epoch:4 train loss:5.2986836433410645 train acc:0.0
epoch:4 train loss:5.2978081703186035 train acc:0.0
epoch:4 train loss:5.298384666442871 train acc:0.0
epoch:4 train loss:5.299614906311035 train acc:0.0
epoch:4 train loss:5.298681735992432 train acc:0.0
epoch:4 train loss:5.300464153289795 train acc:0.0625
epoch:4 train loss:5.300093650817871 train acc:0.0
epoch:4 train loss:5.3009233474731445 train acc:0.0
epoch:4 train loss:5.3002610206604 train acc:0.0
epoch:4 train loss:5.296963214874268 train acc:0.0
epoch:4 train loss:5.296417713165283 train acc:0.0
epoch:4 train loss:5.298097610473633 train acc:0.0
epoch:4 train loss:5.300309181213379 train acc:0.0
epoch:4 train loss:5.29608

HBox(children=(IntProgress(value=0, max=363), HTML(value='')))


epoch:5 train loss:5.297094821929932 train acc:0.0625
epoch:5 train loss:5.299805164337158 train acc:0.0
epoch:5 train loss:5.299352645874023 train acc:0.0
epoch:5 train loss:5.299484729766846 train acc:0.0
epoch:5 train loss:5.299354553222656 train acc:0.0
epoch:5 train loss:5.298647880554199 train acc:0.0
epoch:5 train loss:5.295302867889404 train acc:0.0
epoch:5 train loss:5.297398090362549 train acc:0.0625
epoch:5 train loss:5.29752254486084 train acc:0.0
epoch:5 train loss:5.297981262207031 train acc:0.0
epoch:5 train loss:5.298742294311523 train acc:0.0
epoch:5 train loss:5.29599666595459 train acc:0.0
epoch:5 train loss:5.3003363609313965 train acc:0.0625
epoch:5 train loss:5.298507213592529 train acc:0.0
epoch:5 train loss:5.298564434051514 train acc:0.0
epoch:5 train loss:5.299243927001953 train acc:0.0
epoch:5 train loss:5.298691749572754 train acc:0.0
epoch:5 train loss:5.298595905303955 train acc:0.0
epoch:5 train loss:5.299068450927734 train acc:0.0
epoch:5 train loss:5.3

epoch:5 train loss:5.298140525817871 train acc:0.0
epoch:5 train loss:5.296896457672119 train acc:0.0
epoch:5 train loss:5.300300598144531 train acc:0.0
epoch:5 train loss:5.2984395027160645 train acc:0.0
epoch:5 train loss:5.2971391677856445 train acc:0.0
epoch:5 train loss:5.297495365142822 train acc:0.0
epoch:5 train loss:5.300411701202393 train acc:0.0
epoch:5 train loss:5.299904823303223 train acc:0.0
epoch:5 train loss:5.297513961791992 train acc:0.0
epoch:5 train loss:5.296920299530029 train acc:0.0
epoch:5 train loss:5.299563884735107 train acc:0.0625
epoch:5 train loss:5.296277046203613 train acc:0.0
epoch:5 train loss:5.296623229980469 train acc:0.0
epoch:5 train loss:5.301269054412842 train acc:0.0
epoch:5 train loss:5.29599142074585 train acc:0.0
epoch:5 train loss:5.30143404006958 train acc:0.0
epoch:5 train loss:5.299649715423584 train acc:0.0625
epoch:5 train loss:5.298337936401367 train acc:0.0
epoch:5 train loss:5.300445556640625 train acc:0.0
epoch:5 train loss:5.2998

HBox(children=(IntProgress(value=0, max=363), HTML(value='')))

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "C:\App\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 3267, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-8-bd90bddd2cc3>", line 60, in <module>
    val_result.extend(torch.argmax(logits, dim=-1).cpu().data.numpy())
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:\App\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 2018, in showtraceback
    stb = value._render_traceback_()
AttributeError: 'KeyboardInterrupt' object has no attribute '_render_traceback_'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:\App\Anaconda3\lib\site-packages\IPython\core\ultratb.py", line 1095, in get_records
    return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offset)
  File "C:\App\Anaconda3\lib\site-packages\IPython\co

KeyboardInterrupt: 

# Result

In [None]:
check_dict=load_check_point(args=args)
epoch=check_dict['epoch']
train_acc_arr=check_dict['train_acc_arr']
val_acc_arr=check_dict['val_acc_arr']

print('train acc:{}'.format(train_acc_arr[-1]))
print('val acc:{}'.format(val_acc_arr[-1]))

plt.imshow(read_one_fig(args.train_img))

# 计算validation数据集的iou

In [None]:
#初始化
torch.cuda.empty_cache()

dataset.to_val()
val_dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

iou_result=[]
cls_result=[]
for step, (img_id, img, label, bbox) in enumerate(tqdm_notebook(val_dataloader)):
    img = img.cuda()
    label = label.cuda()

    logits, cam = model.forward(img)
    cls_predict=torch.argmax(logits,dim=-1)
    raw_imgs=get_raw_imgs_by_id(args,img_id,dataset)
    for i in range(len(img_id)):
        target_cam=cam[i][cls_predict[i]].unsqueeze(0).unsqueeze(0)
        raw_img_size=raw_imgs[i].size
        up_cam=F.upsample(target_cam.detach().cpu().data, size=(raw_img_size[1],raw_img_size[0]), mode='bilinear', align_corners=True)
        binary_cam=model.norm_cam_2_binary(up_cam)
        largest_binary_cam=get_max_binary_area(binary_cam.squeeze().numpy())
        gen_bbox=get_bbox_from_binary_cam(largest_binary_cam)
        iou_result.append(get_iou(gen_bbox,[float(x) for x in bbox[i].split(' ')]))
        
    cls_result.extend(cls_predict.cpu().numpy()==label.cpu().numpy())
    
print('iou result on validation is:{}'.format(np.mean((np.array(iou_result)>0.5)*np.array(cls_result))))