In [1]:
import os
import torch
import time
import numpy as np
import torch.nn as nn
from tqdm import tqdm
from torch.cuda import amp
from scipy.io import savemat
import torch.nn.functional as F
from torchvision import datasets, transforms

from Network.ReActNet_18_Qaw import *
from Network.ReActNet_A_Qaw import *
from Network.utils import *

# Restrict this process to GPU 0.
# NOTE(review): this only takes effect if CUDA has not been initialized yet; it is set before the
# first `.cuda()` call in this notebook, but confirm the Network.* imports above do not touch CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

In [2]:
# ---- Optimization ----
Begin_epoch = 0          # first epoch index; also embedded in record / checkpoint file names
Max_epoch = 15           # epochs run are Begin_epoch .. Max_epoch-1 (alternative value: 256)
Learning_rate = 1e-3
Weight_decay = 0
Momentum = 0.9           # NOTE(review): not referenced by any visible cell (optimizer is Adam)
Top_k = 5                # k for the top-k accuracy metric
AMP = False              # mixed-precision training via torch.cuda.amp

# ---- Data ----
Dataset_path = './tests/data/CIFAR10/'
Batch_size = 256
Workers = 8              # DataLoader worker processes
Targetnum = 10           # number of classes

# ---- Evaluation / bookkeeping ----
Test_every_iteration = None   # if an int N, also evaluate every N training iterations
Name_suffix = '_step2'        # appended to every checkpoint / record file name
Savemodel_path = './savemodels/'
Record_path = './recorddata/'

# makedirs(exist_ok=True) also creates missing parent directories and is safe to re-run,
# unlike the exists()/mkdir() pair, which fails on nested paths and races with other runs.
os.makedirs(Savemodel_path, exist_ok=True)
os.makedirs(Record_path, exist_ok=True)

In [3]:
_seed_ = 2023
torch.manual_seed(_seed_)
np.random.seed(_seed_)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [4]:
# Per-channel normalization shared by both splits.
# NOTE(review): these look like the standard ImageNet mean/std rather than CIFAR-10's own
# statistics — confirm that is intended (changing it would invalidate earlier checkpoints).
_normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

# Training augmentation: reflect-pad by 4 px, random horizontal flip, random 32x32 crop.
transform_train = transforms.Compose([
    transforms.Pad(4, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor(),
    _normalize,
])

# Evaluation: no augmentation, only tensor conversion and normalization.
transform_test = transforms.Compose([transforms.ToTensor(), _normalize])

Train_data = datasets.CIFAR10(root=Dataset_path, train=True, download=True, transform=transform_train)
Test_data = datasets.CIFAR10(root=Dataset_path, train=False, download=True, transform=transform_test)

# Loader settings common to both splits.
_loader_kwargs = dict(batch_size=Batch_size, num_workers=Workers, pin_memory=True)

# Training: reshuffle every epoch and drop the ragged final batch.
train_data_loader = torch.utils.data.DataLoader(Train_data, shuffle=True, drop_last=True, **_loader_kwargs)

# Evaluation: fixed order, every sample kept.
test_data_loader = torch.utils.data.DataLoader(Test_data, shuffle=False, drop_last=False, **_loader_kwargs)

In [5]:
# Alternative backbones, kept commented out for quick switching:
# net = ResNet18(num_classes=Targetnum, imagenet=False)
#net = Reactnet(num_classes=Targetnum, imagenet=False)
# Model under training. ReactnetNoLoop arrives through the `from Network.* import *` lines above.
# NOTE(review): which Network module defines it is not visible here — confirm.
net = ReactnetNoLoop(num_classes=Targetnum)

In [6]:
# Wrap the model in DataParallel and move it to the GPU. Checkpoints are therefore state_dicts
# of the wrapped module and must be loaded back into a wrapped net.
net = nn.DataParallel(net).cuda()
max_test_acc = 0.

# Optional warm start from the step-1 best checkpoint. Off by default, which reproduces the
# behaviour of the recorded run (where this path was disabled).
Init_from_step1 = False

if Begin_epoch != 0:
    # Resume: restore the latest per-epoch checkpoint written by the training loop plus the best
    # accuracy so far. Without this, Begin_epoch > 0 would train from scratch while file names
    # and records claim to continue a previous run.
    net.load_state_dict(torch.load(Savemodel_path + f'epoch{Begin_epoch-1}{Name_suffix}.h5'))
    max_test_acc = np.load(Savemodel_path + f'max_acc{Name_suffix}.npy').item()
elif Init_from_step1:
    net.load_state_dict(torch.load(Savemodel_path + 'max_acc_step1.h5'))

# Loss scaler is only needed for mixed-precision training.
scaler = amp.GradScaler() if AMP else None

# Evaluation history; test_model appends one entry per evaluation and dumps it to a .mat file.
Test_top1 = []
Test_topk = []
Test_lossall = []
Epoch_list = []
Iteration_list = []

In [7]:
criterion_train = nn.CrossEntropyLoss()  # alternative tried previously: DistributionLoss()
criterion_test = nn.CrossEntropyLoss()

# Driven by the config-cell constants (currently lr=1e-3, no weight decay) so that editing
# Learning_rate / Weight_decay actually takes effect.
optimizer = torch.optim.Adam(net.parameters(), lr=Learning_rate, weight_decay=Weight_decay)

# Disabled linear-decay LR schedule, kept for reference. LambdaLR with last_epoch >= 0 needs
# 'initial_lr' in the param group; to enable, also call lr_scheduler.step() once per epoch.
# optimizer = torch.optim.Adam([
#     {'params' : net.parameters(), 'weight_decay' : Weight_decay, 'initial_lr': Learning_rate}],
#     lr = Learning_rate)
# lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step : (1.0-step/Max_epoch), last_epoch=Begin_epoch-1)

In [8]:
def test_model(net, max_test_acc, data_loader=test_data_loader, criterion=criterion_test, epoch=None, iteration=None, record=True):
    """Evaluate ``net`` on ``data_loader`` and keep track of the best top-1 accuracy.

    Whenever top-1 accuracy reaches or exceeds ``max_test_acc``, the model's state_dict and the
    new best accuracy are saved under ``Savemodel_path``. With ``record=True`` the metrics are
    also appended to the module-level history lists and the whole history is written to a
    ``.mat`` file under ``Record_path`` (the previous epoch's file is then deleted).

    Args:
        net: model to evaluate; inputs and labels are moved to CUDA.
        max_test_acc: best top-1 accuracy seen so far.
        data_loader: iterable of ``(img, label)`` batches.
        criterion: loss called as ``criterion(logits, label)``. Assumed to return the batch
            mean (it is multiplied by batch size to form a per-sample average) — confirm for
            non-default criteria.
        epoch, iteration: 0-based indices stored in the record; required when ``record``.
        record: whether to update the history lists and write the ``.mat`` file.

    Returns:
        ``(test_loss, test_acc_top1, test_acc_topk, max_test_acc)``: per-sample averages and
        the possibly-updated best top-1 accuracy.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
        AssertionError: if ``record`` is set but ``epoch`` or ``iteration`` is None.
    """
    net.eval()
    test_samples = 0
    test_loss = 0.
    test_acc_top1 = 0.
    test_acc_topk = 0.

    with torch.no_grad():
        for img, label in tqdm(data_loader):
            img = img.cuda()
            label = label.cuda()

            out_fr = net(img)
            loss = criterion(out_fr, label)

            batch_n = label.numel()
            test_samples += batch_n
            # Weight by batch size so the final average is per sample even when the last batch
            # is smaller than the rest.
            test_loss += loss.item() * batch_n

            test_acc_top1 += (out_fr.argmax(1) == label).float().sum().item()
            # topk() raises when k exceeds the number of classes, so cap it.
            k = min(Top_k, out_fr.size(1))
            _, pred = out_fr.topk(k, 1, True, True)
            test_acc_topk += torch.eq(pred, label.view(-1, 1)).float().sum().item()

    if test_samples == 0:
        raise ValueError('test_model: data_loader yielded no samples')

    test_loss /= test_samples
    test_acc_top1 /= test_samples
    test_acc_topk /= test_samples

    # New best (ties included): checkpoint the weights and persist the best accuracy.
    if test_acc_top1 >= max_test_acc:
        max_test_acc = test_acc_top1
        torch.save(net.state_dict(), Savemodel_path + f'max_acc{Name_suffix}.h5')
        np.save(Savemodel_path + f'max_acc{Name_suffix}.npy', np.array(max_test_acc))

    if record:
        assert epoch is not None, "epoch is None!"
        assert iteration is not None, "iteration is None!"

        # Stored 1-based in the exported record.
        Epoch_list.append(epoch+1)
        Iteration_list.append(iteration+1)
        Test_top1.append(test_acc_top1)
        Test_topk.append(test_acc_topk)
        Test_lossall.append(test_loss)

        # One row per evaluation; column order is given by 'Record_meaning'.
        record_data = np.array([Epoch_list, Iteration_list, Test_top1, Test_topk, Test_lossall]).T
        mdic = {'Record_data': record_data, 'Record_meaning': ['Epoch_list', 'Iteration_list', 'Test_top1', f'Test_top{Top_k}', 'Test_loss']}

        # Each file holds the full history so far, so only the newest one is kept on disk.
        savemat(Record_path + f'Test_{Begin_epoch}_{epoch}{Name_suffix}.mat', mdic)
        prev_record = Record_path + f'Test_{Begin_epoch}_{epoch-1}{Name_suffix}.mat'
        if os.path.exists(prev_record):
            os.remove(prev_record)

    return test_loss, test_acc_top1, test_acc_topk, max_test_acc

In [9]:
def train_model(net, max_test_acc, epoch, data_loader=train_data_loader, optimizer=optimizer, criterion=criterion_train, scaler=scaler, record=True):
    """Train ``net`` for one epoch, then evaluate it with ``test_model``.

    Gradients of every parameter whose name does not contain ``'fc'`` are passed through
    ``adaptive_clip_grad`` (clip_factor=0.001) before each optimizer step, on both the FP32 and
    the AMP path. If ``Test_every_iteration`` is set, ``test_model`` also runs every that many
    iterations.

    Args:
        net: model to train; inputs and labels are moved to CUDA.
        max_test_acc: best top-1 test accuracy seen so far (forwarded to ``test_model``).
        epoch: 0-based epoch index, used for recording.
        data_loader: iterable of ``(img, label)`` training batches.
        optimizer: optimizer stepping ``net``'s parameters.
        criterion: training loss called as ``criterion(logits, label)``; assumed to return the
            batch mean — confirm for non-default criteria. Defaults to ``criterion_train``.
        scaler: GradScaler used when the global ``AMP`` flag is on.
        record: forwarded to ``test_model``.

    Returns:
        ``(train_loss, train_acc_top1, train_acc_topk, test_loss, test_acc_top1,
        test_acc_topk, max_test_acc)``; train metrics are per-sample averages over the epoch.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    # The set of clipped parameters never changes, so build it once instead of every step.
    clip_params = [p for name, p in net.named_parameters() if 'fc' not in name]

    train_samples = 0
    train_loss = 0.
    train_acc_top1 = 0.
    train_acc_topk = 0.

    for i, (img, label) in enumerate(tqdm(data_loader)):
        # Re-enter train mode every step: a mid-epoch test_model call switches net to eval().
        net.train()
        img = img.cuda()
        label = label.cuda()

        if AMP:
            with amp.autocast():
                out_fr = net(img)
                loss = criterion(out_fr, label)
        else:
            out_fr = net(img)
            loss = criterion(out_fr, label)

        batch_n = label.numel()
        train_samples += batch_n
        train_loss += loss.item() * batch_n

        train_acc_top1 += (out_fr.argmax(1) == label).float().sum().item()
        # topk() raises when k exceeds the number of classes, so cap it.
        k = min(Top_k, out_fr.size(1))
        _, pred = out_fr.topk(k, 1, True, True)
        train_acc_topk += torch.eq(pred, label.view(-1, 1)).float().sum().item()

        optimizer.zero_grad()
        if AMP:
            scaler.scale(loss).backward()
            # Unscale in place first so clipping sees true gradient magnitudes; scaler.step()
            # then skips its own unscale. Previously AMP runs skipped clipping altogether.
            scaler.unscale_(optimizer)
            adaptive_clip_grad(clip_params, clip_factor=0.001)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            adaptive_clip_grad(clip_params, clip_factor=0.001)
            optimizer.step()

        if Test_every_iteration is not None:
            if (i+1) % Test_every_iteration == 0:
                test_loss, test_acc_top1, test_acc_topk, max_test_acc = test_model(net, max_test_acc, epoch=epoch, iteration=i, record=record)
                print(f'Test_loss: {test_loss:.4f}, Test_acc_top1: {test_acc_top1:.4f}, Test_acc_top{Top_k}: {test_acc_topk:.4f}, Max_test_acc: {max_test_acc:.4f}')

    if train_samples == 0:
        raise ValueError('train_model: data_loader yielded no samples')

    train_loss /= train_samples
    train_acc_top1 /= train_samples
    train_acc_topk /= train_samples

    # End-of-epoch evaluation, recorded against the last iteration index of the epoch.
    test_loss, test_acc_top1, test_acc_topk, max_test_acc = test_model(net, max_test_acc, epoch=epoch, iteration=i, record=record)

    return train_loss, train_acc_top1, train_acc_topk, test_loss, test_acc_top1, test_acc_topk, max_test_acc

In [11]:
for epoch in range(Begin_epoch, Max_epoch):
    start_time = time.time()
    (train_loss, train_acc_top1, train_acc_topk,
     test_loss, test_acc_top1, test_acc_topk,
     max_test_acc) = train_model(net, max_test_acc, epoch)

    # Learning rate reported in the log: that of the last parameter group.
    lr = optimizer.param_groups[-1]['lr']

    # No LR schedule is active; re-enable together with an lr_scheduler definition.
    #lr_scheduler.step()

    print(f'''epoch={epoch}, train_acc_top1={train_acc_top1:.4f}, train_acc_top{Top_k}={train_acc_topk:.4f}, train_loss={train_loss:.4f}, test_top1={test_acc_top1:.4f}, test_top{Top_k}={test_acc_topk:.4f}, test_loss={test_loss:.4f}, max_test_acc={max_test_acc:.4f}, total_time={(time.time() - start_time):.4f}, LR={lr:.8f}''')

    # Keep only the most recent per-epoch checkpoint on disk.
    current_ckpt = Savemodel_path + f'epoch{epoch}{Name_suffix}.h5'
    previous_ckpt = Savemodel_path + f'epoch{epoch-1}{Name_suffix}.h5'
    torch.save(net.state_dict(), current_ckpt)
    if os.path.exists(previous_ckpt):
        os.remove(previous_ckpt)

100%|██████████| 195/195 [00:15<00:00, 12.82it/s]
100%|██████████| 40/40 [00:01<00:00, 35.84it/s]


epoch=0, train_acc_top1=0.3698, train_acc_top5=0.8781, train_loss=1.7076, test_top1=0.4426, test_top5=0.9092, test_loss=1.5568, max_test_acc=0.4426, total_time=16.3500, LR=0.00100000


100%|██████████| 195/195 [00:15<00:00, 12.97it/s]
100%|██████████| 40/40 [00:01<00:00, 37.08it/s]


epoch=1, train_acc_top1=0.4435, train_acc_top5=0.9132, train_loss=1.5278, test_top1=0.4760, test_top5=0.9272, test_loss=1.4563, max_test_acc=0.4760, total_time=16.1367, LR=0.00100000


100%|██████████| 195/195 [00:14<00:00, 13.07it/s]
100%|██████████| 40/40 [00:01<00:00, 37.46it/s]


epoch=2, train_acc_top1=0.4780, train_acc_top5=0.9278, train_loss=1.4444, test_top1=0.4817, test_top5=0.9327, test_loss=1.4087, max_test_acc=0.4817, total_time=15.9985, LR=0.00100000


100%|██████████| 195/195 [00:15<00:00, 12.99it/s]
100%|██████████| 40/40 [00:01<00:00, 37.30it/s]


epoch=3, train_acc_top1=0.5002, train_acc_top5=0.9351, train_loss=1.3803, test_top1=0.5277, test_top5=0.9422, test_loss=1.3224, max_test_acc=0.5277, total_time=16.0923, LR=0.00100000


100%|██████████| 195/195 [00:15<00:00, 12.80it/s]
100%|██████████| 40/40 [00:01<00:00, 36.67it/s]


epoch=4, train_acc_top1=0.5175, train_acc_top5=0.9394, train_loss=1.3400, test_top1=0.5557, test_top5=0.9479, test_loss=1.2694, max_test_acc=0.5557, total_time=16.3331, LR=0.00100000


100%|██████████| 195/195 [00:15<00:00, 12.90it/s]
100%|██████████| 40/40 [00:01<00:00, 36.30it/s]


epoch=5, train_acc_top1=0.5311, train_acc_top5=0.9440, train_loss=1.3105, test_top1=0.5228, test_top5=0.9425, test_loss=1.3206, max_test_acc=0.5557, total_time=16.2273, LR=0.00100000


100%|██████████| 195/195 [00:15<00:00, 12.83it/s]
100%|██████████| 40/40 [00:01<00:00, 37.41it/s]


epoch=6, train_acc_top1=0.5433, train_acc_top5=0.9459, train_loss=1.2828, test_top1=0.5560, test_top5=0.9513, test_loss=1.2513, max_test_acc=0.5560, total_time=16.2768, LR=0.00100000


100%|██████████| 195/195 [00:14<00:00, 13.57it/s]
100%|██████████| 40/40 [00:00<00:00, 40.23it/s]


epoch=7, train_acc_top1=0.5504, train_acc_top5=0.9488, train_loss=1.2563, test_top1=0.5603, test_top5=0.9503, test_loss=1.2256, max_test_acc=0.5603, total_time=15.3879, LR=0.00100000


100%|██████████| 195/195 [00:13<00:00, 14.06it/s]
100%|██████████| 40/40 [00:00<00:00, 40.58it/s]


epoch=8, train_acc_top1=0.5641, train_acc_top5=0.9507, train_loss=1.2248, test_top1=0.5708, test_top5=0.9563, test_loss=1.2026, max_test_acc=0.5708, total_time=14.8686, LR=0.00100000


100%|██████████| 195/195 [00:13<00:00, 14.25it/s]
100%|██████████| 40/40 [00:00<00:00, 40.97it/s]


epoch=9, train_acc_top1=0.5712, train_acc_top5=0.9523, train_loss=1.2081, test_top1=0.5756, test_top5=0.9571, test_loss=1.2018, max_test_acc=0.5756, total_time=14.6809, LR=0.00100000


100%|██████████| 195/195 [00:13<00:00, 14.38it/s]
100%|██████████| 40/40 [00:00<00:00, 40.53it/s]


epoch=10, train_acc_top1=0.5778, train_acc_top5=0.9537, train_loss=1.1923, test_top1=0.5968, test_top5=0.9593, test_loss=1.1443, max_test_acc=0.5968, total_time=14.5623, LR=0.00100000


100%|██████████| 195/195 [00:13<00:00, 14.14it/s]
100%|██████████| 40/40 [00:01<00:00, 38.31it/s]


epoch=11, train_acc_top1=0.5817, train_acc_top5=0.9548, train_loss=1.1797, test_top1=0.5847, test_top5=0.9584, test_loss=1.1607, max_test_acc=0.5968, total_time=14.8368, LR=0.00100000


100%|██████████| 195/195 [00:13<00:00, 13.96it/s]
100%|██████████| 40/40 [00:00<00:00, 40.35it/s]


epoch=12, train_acc_top1=0.5887, train_acc_top5=0.9559, train_loss=1.1630, test_top1=0.5921, test_top5=0.9606, test_loss=1.1381, max_test_acc=0.5968, total_time=14.9647, LR=0.00100000


100%|██████████| 195/195 [00:13<00:00, 14.03it/s]
100%|██████████| 40/40 [00:00<00:00, 40.78it/s]


epoch=13, train_acc_top1=0.5928, train_acc_top5=0.9582, train_loss=1.1476, test_top1=0.6024, test_top5=0.9606, test_loss=1.1223, max_test_acc=0.6024, total_time=14.8869, LR=0.00100000


100%|██████████| 195/195 [00:15<00:00, 12.86it/s]
100%|██████████| 40/40 [00:01<00:00, 36.56it/s]


epoch=14, train_acc_top1=0.5949, train_acc_top5=0.9579, train_loss=1.1433, test_top1=0.5753, test_top5=0.9579, test_loss=1.2022, max_test_acc=0.6024, total_time=16.2629, LR=0.00100000


In [None]:
# Restore the best-accuracy weights saved by test_model. They are a state_dict of the
# DataParallel-wrapped net, so they load back into the same wrapped `net`.
# NOTE(review): this cell shows `In [None]`, i.e. it was not executed in the recorded run. The
# confusion matrix below was therefore computed with the last-epoch weights (acc 0.5753), not
# the best checkpoint (max_test_acc 0.6024). Run this cell first to evaluate the best model.
net.load_state_dict(torch.load(Savemodel_path + f'max_acc{Name_suffix}.h5'))

In [12]:
# Test-set confusion matrix: row = true class, column = predicted class.
Confusion_Matrix = torch.zeros((Targetnum, Targetnum))
net.eval()
with torch.no_grad():
    for img, label in tqdm(test_data_loader):
        img = img.cuda()
        label = label.cuda()
        out_fr = net(img)
        guess = out_fr.argmax(1)
        # Vectorized pair counting: encode each (true, predicted) pair as one flat index,
        # histogram it, and fold it back into a matrix. This replaces a per-sample Python loop
        # that forced a GPU->CPU sync for every element.
        pair_counts = torch.bincount(label * Targetnum + guess, minlength=Targetnum * Targetnum)
        Confusion_Matrix += pair_counts.reshape(Targetnum, Targetnum).float().cpu()
# Overall accuracy = correctly classified (diagonal) / all samples.
acc = Confusion_Matrix.trace() / Confusion_Matrix.sum()
print(f'Confusion_Matrix = {Confusion_Matrix}')
print(f'acc = {acc}')

100%|██████████| 40/40 [00:01<00:00, 22.41it/s]

Confusion_Matrix = tensor([[462., 147.,  40.,  20.,   4.,  15.,  25.,   3., 234.,  50.],
        [ 12., 870.,   0.,   5.,   0.,   1.,  10.,   3.,  20.,  79.],
        [ 94.,  56., 312., 105.,  51.,  85., 223.,  16.,  47.,  11.],
        [ 13.,  41.,  48., 456.,  15., 182., 157.,  14.,  42.,  32.],
        [ 26.,  26.,  87.,  75., 322.,  57., 282.,  60.,  47.,  18.],
        [  6.,  24.,  32., 231.,  28., 541.,  79.,  21.,  19.,  19.],
        [  8.,  26.,  22.,  61.,   3.,  22., 835.,   1.,  16.,   6.],
        [ 31.,  50.,  27.,  71.,  70., 123.,  53., 492.,  18.,  65.],
        [ 32., 112.,   6.,  13.,   2.,   3.,   9.,   3., 777.,  43.],
        [ 20., 219.,   3.,  10.,   1.,   2.,  13.,   2.,  44., 686.]])
acc = 0.5752999782562256



