# 載入套件、設定function、載入模型

In [14]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
from dy_resnet import *
import torch.nn.functional as F
import random, time, torchprofile
import torchvision.models as models
from tqdm import tqdm
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, f1_score


device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Define a custom dataset class
class ImageNetDataset(Dataset):
    def __init__(self, data_file, root_dir, transform=None):
        """
        Args:
            data_file (string): Path to the file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        # Read the data file
        with open(data_file, 'r') as file:
            self.data = file.readlines()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Process each line
        line = self.data[idx].strip()
        img_path, label = line.split()
        # Complete image path
        img_path = os.path.join(self.root_dir, img_path)
        image = Image.open(img_path).convert('RGB')
        label = int(label)
        
        if self.transform:
            image = self.transform(image)

        return image, label

class RandomChannelDropout:
    def __init__(self):
        self.channel_combinations = [
            (0, 1, 2),  # RGB
            (0, 1),     # RG
            (0, 2),     # RB
            (1, 2),     # GB
            (0,),       # R
            (1,),       # G
            (2,)        # B
        ]

    def __call__(self, img):
        channels = list(img.split())
        chosen_combination = random.choice(self.channel_combinations)
        combined_img = Image.merge('RGB', [channels[i] if i in chosen_combination else Image.new('L', img.size) for i in range(3)])
        return combined_img

# Define transformations

#基礎影像處理
basic_augmentations = transforms.Compose([
    transforms.Resize(96),
    transforms.CenterCrop(96),
    transforms.ToTensor()
])

train_augmentations = transforms.Compose([
    transforms.RandomResizedCrop(96),
    transforms.RandomHorizontalFlip(1),
    transforms.RandomVerticalFlip(0.1),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.2),
    transforms.RandomRotation(15),
    RandomChannelDropout(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

val_augmentations = transforms.Compose([
    transforms.Resize(96),
    transforms.CenterCrop(96),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

test_augmentations = transforms.Compose([
    transforms.Resize(96),
    transforms.CenterCrop(96),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

class DynamicConv2D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
        super(DynamicConv2D, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.bias = bias

        # 定義動態卷積核權重和偏置
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.randn(out_channels))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        # 根據輸入的channel數量進行權重調整
        current_in_channels = x.size(1)
        if current_in_channels < self.in_channels:
            weight = self.weight[:, :current_in_channels, :, :]
        else:
            weight = self.weight
        
        # 應用卷積操作
        return F.conv2d(x, weight, self.bias, self.stride, self.padding)

is_train = True
model_path = "best_1.pt"
if is_train:
    # 選擇ResNet-18作為模型
    model = models.resnet18(pretrained=False)
    model.conv1 = DynamicConv2D(in_channels=3, out_channels=64, kernel_size=7, padding=1)
    model.fc = torch.nn.Linear(model.fc.in_features, 50)
    model.to(device)
else:
    model = models.resnet18(pretrained=False)
    model.conv1 = DynamicConv2D(in_channels=3, out_channels=64, kernel_size=7, padding=1)
    model.fc = torch.nn.Linear(model.fc.in_features, 50)
    model = torch.load(model_path)
    model.eval()


Exception in thread Thread-7:
Traceback (most recent call last):
  File "c:\Users\wen\anaconda3\envs\wen\lib\threading.py", line 980, in _bootstrap_inner
    self.run()
  File "c:\Users\wen\anaconda3\envs\wen\lib\site-packages\ipykernel\ipkernel.py", line 761, in run_closure
    _threading_Thread_run(self)
  File "c:\Users\wen\anaconda3\envs\wen\lib\threading.py", line 917, in run
    self._target(*self._args, **self._kwargs)
  File "c:\Users\wen\anaconda3\envs\wen\lib\subprocess.py", line 1495, in _readerthread
    buffer.append(fh.read())
  File "c:\Users\wen\anaconda3\envs\wen\lib\codecs.py", line 322, in decode
    (result, consumed) = self._buffer_decode(data, self.errors, final)
UnicodeDecodeError: 'utf-8' codec can't decode byte 0xa4 in position 6: invalid start byte


# Load Dataset

In [9]:
# Create dataset instances
train_dataset = ImageNetDataset(data_file='train.txt', root_dir='', transform=train_augmentations)
val_dataset = ImageNetDataset(data_file='val.txt', root_dir='', transform=val_augmentations)
test_dataset = ImageNetDataset(data_file='test.txt', root_dir='', transform=test_augmentations)

# Create DataLoader instances
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 訓練DynamicConv2D ResNet18

In [4]:
epochs = 30
# batch_size = 32
# test_batch_size = 32
lr = 0.1
momentum = 0.9
weight_decay = 1e-4
net_name = 'dy_resnet'
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

def adjust_lr(optimizer, epoch):
    if epoch in [epochs*0.5, epochs*0.75, epochs*0.85]:
        for p in optimizer.param_groups:
            p['lr'] *= 0.1
            lr = p['lr']
        print('Change lr:'+str(lr))

def train(epoch):
    model.train()
    avg_loss = 0.
    train_acc = 0.
    adjust_lr(optimizer, epoch)
    train_loader_len = len(train_loader.dataset)
    train_loader_iter = tqdm(enumerate(train_loader), total=len(train_loader), desc="Training Epoch #{}".format(epoch))
    
    for batch_idx, (data, target) in train_loader_iter:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        avg_loss += loss.item()
        pred = output.data.max(1, keepdim=True)[1]
        train_acc += pred.eq(target.data.view_as(pred)).cpu().sum()
        loss.backward()
        optimizer.step()
        train_loader_iter.set_postfix(loss=loss.item(), accuracy=100. * train_acc.item() / train_loader_len)
    
    print('Train Epoch: {}, Loss: {:.6f}, Accuracy: {:.2f}%'.format(epoch, avg_loss / len(train_loader), 100. * train_acc / train_loader_len))



def val(epoch):
    model.eval()
    test_loss = 0.
    correct = 0
    val_loader_iter = tqdm(val_loader, total=len(val_loader), desc="Validation Epoch #{}".format(epoch))
    
    with torch.no_grad():
        for data, label in val_loader_iter:
            data, label = data.to(device), label.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, label, reduction='sum').item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(label.data.view_as(pred)).cpu().sum()
            val_loader_iter.set_postfix(loss=test_loss / len(val_loader.dataset), accuracy=100. * correct.item() / len(val_loader.dataset))
    
    test_loss /= len(val_loader.dataset)
    accuracy = 100. * correct / len(val_loader.dataset)
    print('Validation Set: Average Loss: {:.4f}, Accuracy: {:.2f}%'.format(test_loss, accuracy))

    return accuracy

if is_train:
    best_val_acc = 0.
    for i in range(epochs):
        train(i + 1)
        temp_acc = val(i + 1)
        if temp_acc > best_val_acc:
            best_val_acc = temp_acc
            torch.save(model.state_dict(), 'best_1.pt')
            print('Best Accuracy: {:.2f}%'.format(best_val_acc))

    print('Final Best Accuracy: {:.2f}%'.format(best_val_acc))

Training Epoch #1: 100%|██████████| 990/990 [03:36<00:00,  4.58it/s, accuracy=5.02, loss=3.75] 


Train Epoch: 1, Loss: 3.793766, Accuracy: 5.02%


Validation Epoch #1: 100%|██████████| 8/8 [00:00<00:00,  8.15it/s, accuracy=9.33, loss=3.6]  


Validation Set: Average Loss: 3.5951, Accuracy: 9.33%
Best Accuracy: 9.33%


Training Epoch #2: 100%|██████████| 990/990 [03:35<00:00,  4.60it/s, accuracy=8.63, loss=3.73] 


Train Epoch: 2, Loss: 3.564647, Accuracy: 8.63%


Validation Epoch #2: 100%|██████████| 8/8 [00:01<00:00,  7.85it/s, accuracy=15.3, loss=3.26] 


Validation Set: Average Loss: 3.2553, Accuracy: 15.33%
Best Accuracy: 15.33%


Training Epoch #3: 100%|██████████| 990/990 [03:37<00:00,  4.56it/s, accuracy=11.7, loss=3.11] 


Train Epoch: 3, Loss: 3.404045, Accuracy: 11.66%


Validation Epoch #3: 100%|██████████| 8/8 [00:01<00:00,  7.26it/s, accuracy=16.7, loss=3.08] 


Validation Set: Average Loss: 3.0778, Accuracy: 16.67%
Best Accuracy: 16.67%


Training Epoch #4: 100%|██████████| 990/990 [03:02<00:00,  5.43it/s, accuracy=15, loss=3.13]  


Train Epoch: 4, Loss: 3.239766, Accuracy: 15.02%


Validation Epoch #4: 100%|██████████| 8/8 [00:00<00:00, 12.49it/s, accuracy=18.4, loss=2.99] 


Validation Set: Average Loss: 2.9877, Accuracy: 18.44%
Best Accuracy: 18.44%


Training Epoch #5: 100%|██████████| 990/990 [02:22<00:00,  6.97it/s, accuracy=18.7, loss=3.02]


Train Epoch: 5, Loss: 3.050278, Accuracy: 18.73%


Validation Epoch #5: 100%|██████████| 8/8 [00:00<00:00, 12.53it/s, accuracy=24, loss=2.63]   


Validation Set: Average Loss: 2.6294, Accuracy: 24.00%
Best Accuracy: 24.00%


Training Epoch #6: 100%|██████████| 990/990 [02:50<00:00,  5.80it/s, accuracy=22.2, loss=2.99]


Train Epoch: 6, Loss: 2.871903, Accuracy: 22.25%


Validation Epoch #6: 100%|██████████| 8/8 [00:01<00:00,  7.60it/s, accuracy=26.4, loss=2.5]  


Validation Set: Average Loss: 2.4977, Accuracy: 26.44%
Best Accuracy: 26.44%


Training Epoch #7: 100%|██████████| 990/990 [02:46<00:00,  5.94it/s, accuracy=25.8, loss=2.59]


Train Epoch: 7, Loss: 2.720708, Accuracy: 25.77%


Validation Epoch #7: 100%|██████████| 8/8 [00:01<00:00,  7.17it/s, accuracy=34.2, loss=2.4]  


Validation Set: Average Loss: 2.4013, Accuracy: 34.22%
Best Accuracy: 34.22%


Training Epoch #8: 100%|██████████| 990/990 [03:41<00:00,  4.47it/s, accuracy=28.5, loss=2.85]


Train Epoch: 8, Loss: 2.603299, Accuracy: 28.55%


Validation Epoch #8: 100%|██████████| 8/8 [00:01<00:00,  7.63it/s, accuracy=38.4, loss=2.16] 


Validation Set: Average Loss: 2.1550, Accuracy: 38.44%
Best Accuracy: 38.44%


Training Epoch #9: 100%|██████████| 990/990 [02:57<00:00,  5.57it/s, accuracy=30.5, loss=2.29]


Train Epoch: 9, Loss: 2.512499, Accuracy: 30.48%


Validation Epoch #9: 100%|██████████| 8/8 [00:01<00:00,  6.76it/s, accuracy=40.2, loss=2.03] 


Validation Set: Average Loss: 2.0336, Accuracy: 40.22%
Best Accuracy: 40.22%


Training Epoch #10: 100%|██████████| 990/990 [03:15<00:00,  5.07it/s, accuracy=32.6, loss=2.68]


Train Epoch: 10, Loss: 2.426343, Accuracy: 32.57%


Validation Epoch #10: 100%|██████████| 8/8 [00:01<00:00,  7.79it/s, accuracy=35.6, loss=2.23] 


Validation Set: Average Loss: 2.2339, Accuracy: 35.56%


Training Epoch #11: 100%|██████████| 990/990 [03:23<00:00,  4.87it/s, accuracy=34.3, loss=2.51]


Train Epoch: 11, Loss: 2.348786, Accuracy: 34.31%


Validation Epoch #11: 100%|██████████| 8/8 [00:00<00:00, 11.95it/s, accuracy=44.9, loss=1.93] 


Validation Set: Average Loss: 1.9279, Accuracy: 44.89%
Best Accuracy: 44.89%


Training Epoch #12: 100%|██████████| 990/990 [02:22<00:00,  6.97it/s, accuracy=35.7, loss=2.33]


Train Epoch: 12, Loss: 2.290356, Accuracy: 35.71%


Validation Epoch #12: 100%|██████████| 8/8 [00:00<00:00, 12.43it/s, accuracy=38.4, loss=2.06] 


Validation Set: Average Loss: 2.0624, Accuracy: 38.44%


Training Epoch #13: 100%|██████████| 990/990 [02:21<00:00,  7.01it/s, accuracy=37.5, loss=2.38]


Train Epoch: 13, Loss: 2.232722, Accuracy: 37.49%


Validation Epoch #13: 100%|██████████| 8/8 [00:00<00:00, 12.19it/s, accuracy=44.7, loss=1.8]  


Validation Set: Average Loss: 1.7950, Accuracy: 44.67%


Training Epoch #14: 100%|██████████| 990/990 [03:18<00:00,  4.98it/s, accuracy=38.6, loss=2.24]


Train Epoch: 14, Loss: 2.187071, Accuracy: 38.56%


Validation Epoch #14: 100%|██████████| 8/8 [00:01<00:00,  7.34it/s, accuracy=46.4, loss=1.85] 


Validation Set: Average Loss: 1.8467, Accuracy: 46.44%
Best Accuracy: 46.44%
Change lr:0.010000000000000002


Training Epoch #15: 100%|██████████| 990/990 [03:35<00:00,  4.60it/s, accuracy=46.3, loss=1.78]


Train Epoch: 15, Loss: 1.886494, Accuracy: 46.34%


Validation Epoch #15: 100%|██████████| 8/8 [00:01<00:00,  7.66it/s, accuracy=60, loss=1.31]   


Validation Set: Average Loss: 1.3105, Accuracy: 60.00%
Best Accuracy: 60.00%


Training Epoch #16: 100%|██████████| 990/990 [03:35<00:00,  4.59it/s, accuracy=48.6, loss=2.03]


Train Epoch: 16, Loss: 1.793619, Accuracy: 48.63%


Validation Epoch #16: 100%|██████████| 8/8 [00:01<00:00,  7.84it/s, accuracy=60.4, loss=1.28] 


Validation Set: Average Loss: 1.2820, Accuracy: 60.44%
Best Accuracy: 60.44%


Training Epoch #17: 100%|██████████| 990/990 [03:36<00:00,  4.58it/s, accuracy=49.6, loss=1.48]


Train Epoch: 17, Loss: 1.747732, Accuracy: 49.62%


Validation Epoch #17: 100%|██████████| 8/8 [00:01<00:00,  7.64it/s, accuracy=60.2, loss=1.26] 


Validation Set: Average Loss: 1.2571, Accuracy: 60.22%


Training Epoch #18: 100%|██████████| 990/990 [03:38<00:00,  4.53it/s, accuracy=50.4, loss=1.84]


Train Epoch: 18, Loss: 1.716956, Accuracy: 50.42%


Validation Epoch #18: 100%|██████████| 8/8 [00:01<00:00,  7.87it/s, accuracy=61.1, loss=1.24] 


Validation Set: Average Loss: 1.2380, Accuracy: 61.11%
Best Accuracy: 61.11%


Training Epoch #19: 100%|██████████| 990/990 [03:36<00:00,  4.57it/s, accuracy=51.2, loss=1.52]


Train Epoch: 19, Loss: 1.694842, Accuracy: 51.16%


Validation Epoch #19: 100%|██████████| 8/8 [00:01<00:00,  7.94it/s, accuracy=62, loss=1.25]   


Validation Set: Average Loss: 1.2507, Accuracy: 62.00%
Best Accuracy: 62.00%


Training Epoch #20: 100%|██████████| 990/990 [03:10<00:00,  5.19it/s, accuracy=51.7, loss=1.61]


Train Epoch: 20, Loss: 1.668227, Accuracy: 51.68%


Validation Epoch #20: 100%|██████████| 8/8 [00:01<00:00,  7.02it/s, accuracy=62.4, loss=1.25] 


Validation Set: Average Loss: 1.2542, Accuracy: 62.44%
Best Accuracy: 62.44%


Training Epoch #21: 100%|██████████| 990/990 [03:20<00:00,  4.95it/s, accuracy=52.6, loss=2.21] 


Train Epoch: 21, Loss: 1.640695, Accuracy: 52.56%


Validation Epoch #21: 100%|██████████| 8/8 [00:01<00:00,  7.60it/s, accuracy=61.3, loss=1.23] 


Validation Set: Average Loss: 1.2330, Accuracy: 61.33%


Training Epoch #22: 100%|██████████| 990/990 [03:36<00:00,  4.57it/s, accuracy=52.9, loss=1.26] 


Train Epoch: 22, Loss: 1.629080, Accuracy: 52.89%


Validation Epoch #22: 100%|██████████| 8/8 [00:01<00:00,  7.58it/s, accuracy=63.6, loss=1.21] 


Validation Set: Average Loss: 1.2082, Accuracy: 63.56%
Best Accuracy: 63.56%


Training Epoch #23: 100%|██████████| 990/990 [03:37<00:00,  4.56it/s, accuracy=53.2, loss=1.72]


Train Epoch: 23, Loss: 1.609286, Accuracy: 53.22%


Validation Epoch #23: 100%|██████████| 8/8 [00:01<00:00,  7.87it/s, accuracy=63.6, loss=1.17] 


Validation Set: Average Loss: 1.1740, Accuracy: 63.56%


Training Epoch #24: 100%|██████████| 990/990 [03:35<00:00,  4.59it/s, accuracy=54, loss=1.42]  


Train Epoch: 24, Loss: 1.586003, Accuracy: 53.99%


Validation Epoch #24: 100%|██████████| 8/8 [00:01<00:00,  7.68it/s, accuracy=64.2, loss=1.19] 


Validation Set: Average Loss: 1.1927, Accuracy: 64.22%
Best Accuracy: 64.22%


Training Epoch #25: 100%|██████████| 990/990 [03:21<00:00,  4.91it/s, accuracy=54, loss=1.67]  


Train Epoch: 25, Loss: 1.581621, Accuracy: 53.97%


Validation Epoch #25: 100%|██████████| 8/8 [00:00<00:00, 12.54it/s, accuracy=61.8, loss=1.19] 


Validation Set: Average Loss: 1.1911, Accuracy: 61.78%


Training Epoch #26: 100%|██████████| 990/990 [02:21<00:00,  7.00it/s, accuracy=54.4, loss=1.51]


Train Epoch: 26, Loss: 1.562532, Accuracy: 54.42%


Validation Epoch #26: 100%|██████████| 8/8 [00:00<00:00, 12.26it/s, accuracy=65.3, loss=1.11] 


Validation Set: Average Loss: 1.1124, Accuracy: 65.33%
Best Accuracy: 65.33%


Training Epoch #27: 100%|██████████| 990/990 [02:22<00:00,  6.94it/s, accuracy=55.1, loss=1.52]


Train Epoch: 27, Loss: 1.544095, Accuracy: 55.10%


Validation Epoch #27: 100%|██████████| 8/8 [00:00<00:00, 12.55it/s, accuracy=66, loss=1.13]   


Validation Set: Average Loss: 1.1327, Accuracy: 66.00%
Best Accuracy: 66.00%


Training Epoch #28: 100%|██████████| 990/990 [02:22<00:00,  6.96it/s, accuracy=55.5, loss=1.75]


Train Epoch: 28, Loss: 1.528778, Accuracy: 55.46%


Validation Epoch #28: 100%|██████████| 8/8 [00:00<00:00, 12.49it/s, accuracy=64.2, loss=1.14] 


Validation Set: Average Loss: 1.1364, Accuracy: 64.22%


Training Epoch #29: 100%|██████████| 990/990 [02:45<00:00,  6.00it/s, accuracy=55.6, loss=1.4]  


Train Epoch: 29, Loss: 1.522047, Accuracy: 55.64%


Validation Epoch #29: 100%|██████████| 8/8 [00:00<00:00, 12.36it/s, accuracy=65.1, loss=1.12] 


Validation Set: Average Loss: 1.1225, Accuracy: 65.11%


Training Epoch #30: 100%|██████████| 990/990 [02:34<00:00,  6.40it/s, accuracy=56.1, loss=1.98] 


Train Epoch: 30, Loss: 1.503509, Accuracy: 56.08%


Validation Epoch #30: 100%|██████████| 8/8 [00:00<00:00, 12.31it/s, accuracy=63.6, loss=1.13] 

Validation Set: Average Loss: 1.1327, Accuracy: 63.56%
Final Best Accuracy: 66.00%





# 載入DynamicConv2D ResNet18

In [17]:
model = models.resnet18(pretrained=False)
model.conv1 = DynamicConv2D(in_channels=3, out_channels=64, kernel_size=7, padding=1)
model.fc = torch.nn.Linear(model.fc.in_features, 50)
model.load_state_dict(torch.load("best_1.pt"))
model.to(device)
model.eval()



ResNet(
  (conv1): DynamicConv2D()
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1)

In [18]:
def select_channels(image, channels='RGB'):
    if channels == 'R':
        return image[0, :, :].unsqueeze(0)
    elif channels == 'G':
        return image[1, :, :].unsqueeze(0)
    elif channels == 'B':
        return image[2, :, :].unsqueeze(0)
    elif channels == 'RG':
        return image[0:2, :, :]
    elif channels == 'RB':
        return torch.stack([image[0, :, :], image[2, :, :]], dim=0)
    elif channels == 'GB':
        return image[1:, :, :]
    elif channels == 'RGB':
        return image
    else:
        raise ValueError('Invalid channel selection')

In [19]:
def rgb_dataloader(set_channel = 'RGB'):
    test_augmentations = transforms.Compose([
        transforms.Resize(96),
        transforms.CenterCrop(96),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        transforms.Lambda(lambda x: select_channels(x, set_channel))  # Change 'RGB' to other combinations as needed
    ])

    test_dataset = ImageNetDataset(data_file='test.txt', root_dir='', transform=test_augmentations)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
    return test_loader

# 測試集驗證

In [21]:
def test(model, rgb_set):
    model.eval()
    correct = 0
    test_loader = rgb_dataloader(set_channel = rgb_set)
    test_loader_len = len(test_loader.dataset)
    test_loader_iter = tqdm(test_loader, total=len(test_loader), desc="Testing")
    
    all_preds = []
    all_targets = []
    
    start_time = time.time()
    
    with torch.no_grad():
        for data, target in test_loader_iter:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()
            
            all_preds.extend(pred.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
            
            test_loader_iter.set_postfix(accuracy=100. * correct.item() / test_loader_len)
    
    end_time = time.time()
    elapsed_time = end_time - start_time
    
    accuracy = 100. * correct / test_loader_len
    
    # 計算 Precision, Recall 和 F1-score
    precision = 100. * precision_score(all_targets, all_preds, average='macro')
    recall = 100. * recall_score(all_targets, all_preds, average='macro')
    f1 = 100. * f1_score(all_targets, all_preds, average='macro')
    
    # 計算 FLOPS
    flops = torchprofile.profile_macs(model, torch.randn(1, *data.shape[1:]).to(device))
    
    return accuracy.item(), precision, recall, f1, flops, elapsed_time

rgb_list = ["RGB", "RG", "RB", "GB", "R", "G", "B"]
for rgb in rgb_list:
    resnet_acc, resnet_precision, resnet_recall, resnet_f1, resnet_flops, resnet_elapsed_time = test(model,rgb)
    print(f"RGB Set: {rgb:s}, Accuracy: {resnet_acc:.2f}%, Precision: {resnet_precision:.2f}%, Recall: {resnet_recall:.2f}%, F1 Score: {resnet_f1:.2f}%, FLOPS: {resnet_flops:d}, Elapsed Time: {resnet_elapsed_time:.2f} seconds")
    # resnet18_acc, resnet18_elapsed_time = test(model,rgb)
    # print(f"RGB Set: {rgb:s}, Accuracy: {resnet18_acc:.2f}%, Elapsed Time: {resnet18_elapsed_time:.2f} seconds")

Testing: 100%|██████████| 8/8 [00:00<00:00,  9.87it/s, accuracy=69.6]


RGB Set: RGB, Accuracy: 69.56%, Precision: 71.25%, Recall: 69.56%, F1 Score: 68.36%, FLOPS: 1274698880, Elapsed Time: 0.81 seconds


Testing: 100%|██████████| 8/8 [00:00<00:00, 12.19it/s, accuracy=57.3]


RGB Set: RG, Accuracy: 57.33%, Precision: 61.72%, Recall: 57.33%, F1 Score: 56.81%, FLOPS: 1248155776, Elapsed Time: 0.66 seconds


Testing: 100%|██████████| 8/8 [00:00<00:00, 11.59it/s, accuracy=56.9]


RGB Set: RB, Accuracy: 56.89%, Precision: 60.93%, Recall: 56.89%, F1 Score: 56.26%, FLOPS: 1248155776, Elapsed Time: 0.69 seconds


Testing: 100%|██████████| 8/8 [00:00<00:00, 10.61it/s, accuracy=53.8]


RGB Set: GB, Accuracy: 53.78%, Precision: 57.50%, Recall: 53.78%, F1 Score: 52.75%, FLOPS: 1248155776, Elapsed Time: 0.76 seconds


Testing: 100%|██████████| 8/8 [00:00<00:00, 11.14it/s, accuracy=24.4]
  _warn_prf(average, modifier, msg_start, len(result))


RGB Set: R, Accuracy: 24.44%, Precision: 34.60%, Recall: 24.44%, F1 Score: 21.76%, FLOPS: 1221612672, Elapsed Time: 0.72 seconds


Testing: 100%|██████████| 8/8 [00:00<00:00, 10.87it/s, accuracy=22.2]
  _warn_prf(average, modifier, msg_start, len(result))


RGB Set: G, Accuracy: 22.22%, Precision: 31.10%, Recall: 22.22%, F1 Score: 20.10%, FLOPS: 1221612672, Elapsed Time: 0.74 seconds


Testing: 100%|██████████| 8/8 [00:00<00:00, 10.74it/s, accuracy=21.3]

RGB Set: B, Accuracy: 21.33%, Precision: 26.59%, Recall: 21.33%, F1 Score: 17.88%, FLOPS: 1221612672, Elapsed Time: 0.75 seconds



  _warn_prf(average, modifier, msg_start, len(result))


# 載入ResNet18

In [7]:
model = models.resnet18(pretrained=False)
model.fc = torch.nn.Linear(model.fc.in_features, 50)
model.to(device)
model.eval()


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

# 訓練ResNet18

In [10]:
epochs = 30
# batch_size = 32
# test_batch_size = 32
lr = 0.1
momentum = 0.9
weight_decay = 1e-4
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

def adjust_lr(optimizer, epoch):
    if epoch in [epochs*0.5, epochs*0.75, epochs*0.85]:
        for p in optimizer.param_groups:
            p['lr'] *= 0.1
            lr = p['lr']
        print('Change lr:'+str(lr))

def train(epoch):
    model.train()
    avg_loss = 0.
    train_acc = 0.
    adjust_lr(optimizer, epoch)
    train_loader_len = len(train_loader.dataset)
    train_loader_iter = tqdm(enumerate(train_loader), total=len(train_loader), desc="Training Epoch #{}".format(epoch))
    
    for batch_idx, (data, target) in train_loader_iter:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        avg_loss += loss.item()
        pred = output.data.max(1, keepdim=True)[1]
        train_acc += pred.eq(target.data.view_as(pred)).cpu().sum()
        loss.backward()
        optimizer.step()
        train_loader_iter.set_postfix(loss=loss.item(), accuracy=100. * train_acc.item() / train_loader_len)
    
    print('Train Epoch: {}, Loss: {:.6f}, Accuracy: {:.2f}%'.format(epoch, avg_loss / len(train_loader), 100. * train_acc / train_loader_len))



def val(epoch):
    model.eval()
    test_loss = 0.
    correct = 0
    val_loader_iter = tqdm(val_loader, total=len(val_loader), desc="Validation Epoch #{}".format(epoch))
    
    with torch.no_grad():
        for data, label in val_loader_iter:
            data, label = data.to(device), label.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, label, reduction='sum').item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(label.data.view_as(pred)).cpu().sum()
            val_loader_iter.set_postfix(loss=test_loss / len(val_loader.dataset), accuracy=100. * correct.item() / len(val_loader.dataset))
    
    test_loss /= len(val_loader.dataset)
    accuracy = 100. * correct / len(val_loader.dataset)
    print('Validation Set: Average Loss: {:.4f}, Accuracy: {:.2f}%'.format(test_loss, accuracy))

    return accuracy

if is_train:
    best_val_acc = 0.
    for i in range(epochs):
        train(i + 1)
        temp_acc = val(i + 1)
        if temp_acc > best_val_acc:
            best_val_acc = temp_acc
            torch.save(model.state_dict(), 'resnet18_best.pt')
            print('Best Accuracy: {:.2f}%'.format(best_val_acc))

    print('Final Best Accuracy: {:.2f}%'.format(best_val_acc))

Training Epoch #1: 100%|██████████| 990/990 [02:17<00:00,  7.22it/s, accuracy=3.14, loss=3.75] 


Train Epoch: 1, Loss: 3.901848, Accuracy: 3.14%


Validation Epoch #1: 100%|██████████| 8/8 [00:00<00:00, 12.61it/s, accuracy=4.89, loss=3.72]


Validation Set: Average Loss: 3.7242, Accuracy: 4.89%
Best Accuracy: 4.89%


Training Epoch #2: 100%|██████████| 990/990 [01:57<00:00,  8.43it/s, accuracy=5.26, loss=3.62] 


Train Epoch: 2, Loss: 3.763737, Accuracy: 5.26%


Validation Epoch #2: 100%|██████████| 8/8 [00:00<00:00, 12.69it/s, accuracy=10.7, loss=3.52] 


Validation Set: Average Loss: 3.5248, Accuracy: 10.67%
Best Accuracy: 10.67%


Training Epoch #3: 100%|██████████| 990/990 [02:59<00:00,  5.53it/s, accuracy=7.83, loss=3.54] 


Train Epoch: 3, Loss: 3.611340, Accuracy: 7.83%


Validation Epoch #3: 100%|██████████| 8/8 [00:00<00:00,  8.39it/s, accuracy=13.6, loss=3.37] 


Validation Set: Average Loss: 3.3712, Accuracy: 13.56%
Best Accuracy: 13.56%


Training Epoch #4: 100%|██████████| 990/990 [03:22<00:00,  4.90it/s, accuracy=10.4, loss=3.57] 


Train Epoch: 4, Loss: 3.485261, Accuracy: 10.39%


Validation Epoch #4: 100%|██████████| 8/8 [00:00<00:00,  8.02it/s, accuracy=14.9, loss=3.18] 


Validation Set: Average Loss: 3.1847, Accuracy: 14.89%
Best Accuracy: 14.89%


Training Epoch #5: 100%|██████████| 990/990 [03:22<00:00,  4.89it/s, accuracy=12.1, loss=3.7] 


Train Epoch: 5, Loss: 3.390229, Accuracy: 12.15%


Validation Epoch #5: 100%|██████████| 8/8 [00:00<00:00,  8.08it/s, accuracy=20.7, loss=3.08] 


Validation Set: Average Loss: 3.0821, Accuracy: 20.67%
Best Accuracy: 20.67%


Training Epoch #6: 100%|██████████| 990/990 [03:22<00:00,  4.90it/s, accuracy=13.8, loss=3.37]


Train Epoch: 6, Loss: 3.297329, Accuracy: 13.76%


Validation Epoch #6: 100%|██████████| 8/8 [00:00<00:00,  8.02it/s, accuracy=21.6, loss=2.95] 


Validation Set: Average Loss: 2.9524, Accuracy: 21.56%
Best Accuracy: 21.56%


Training Epoch #7: 100%|██████████| 990/990 [02:37<00:00,  6.30it/s, accuracy=15.6, loss=3.33]


Train Epoch: 7, Loss: 3.212652, Accuracy: 15.62%


Validation Epoch #7: 100%|██████████| 8/8 [00:00<00:00, 13.08it/s, accuracy=18.9, loss=2.9]  


Validation Set: Average Loss: 2.8989, Accuracy: 18.89%


Training Epoch #8: 100%|██████████| 990/990 [02:02<00:00,  8.11it/s, accuracy=17.7, loss=3.37]


Train Epoch: 8, Loss: 3.121726, Accuracy: 17.66%


Validation Epoch #8: 100%|██████████| 8/8 [00:00<00:00, 12.56it/s, accuracy=26, loss=2.74]  


Validation Set: Average Loss: 2.7386, Accuracy: 26.00%
Best Accuracy: 26.00%


Training Epoch #9: 100%|██████████| 990/990 [02:05<00:00,  7.90it/s, accuracy=19.4, loss=3.08]


Train Epoch: 9, Loss: 3.035861, Accuracy: 19.39%


Validation Epoch #9: 100%|██████████| 8/8 [00:00<00:00,  8.28it/s, accuracy=26.9, loss=2.59] 


Validation Set: Average Loss: 2.5941, Accuracy: 26.89%
Best Accuracy: 26.89%


Training Epoch #10: 100%|██████████| 990/990 [03:15<00:00,  5.06it/s, accuracy=20.8, loss=3.21]


Train Epoch: 10, Loss: 2.962334, Accuracy: 20.80%


Validation Epoch #10: 100%|██████████| 8/8 [00:00<00:00,  8.16it/s, accuracy=31.6, loss=2.52] 


Validation Set: Average Loss: 2.5236, Accuracy: 31.56%
Best Accuracy: 31.56%


Training Epoch #11: 100%|██████████| 990/990 [02:41<00:00,  6.15it/s, accuracy=22.2, loss=2.71]


Train Epoch: 11, Loss: 2.887143, Accuracy: 22.20%


Validation Epoch #11: 100%|██████████| 8/8 [00:00<00:00, 13.13it/s, accuracy=27.8, loss=2.64] 


Validation Set: Average Loss: 2.6422, Accuracy: 27.78%


Training Epoch #12: 100%|██████████| 990/990 [03:05<00:00,  5.35it/s, accuracy=23.6, loss=2.67]


Train Epoch: 12, Loss: 2.835069, Accuracy: 23.61%


Validation Epoch #12: 100%|██████████| 8/8 [00:01<00:00,  7.99it/s, accuracy=31.3, loss=2.44] 


Validation Set: Average Loss: 2.4353, Accuracy: 31.33%


Training Epoch #13: 100%|██████████| 990/990 [02:29<00:00,  6.64it/s, accuracy=25.1, loss=2.75]


Train Epoch: 13, Loss: 2.771141, Accuracy: 25.07%


Validation Epoch #13: 100%|██████████| 8/8 [00:00<00:00, 13.29it/s, accuracy=32.2, loss=2.37] 


Validation Set: Average Loss: 2.3727, Accuracy: 32.22%
Best Accuracy: 32.22%


Training Epoch #14: 100%|██████████| 990/990 [03:06<00:00,  5.32it/s, accuracy=25.8, loss=2.68]


Train Epoch: 14, Loss: 2.729147, Accuracy: 25.84%


Validation Epoch #14: 100%|██████████| 8/8 [00:00<00:00,  8.10it/s, accuracy=34.2, loss=2.34] 


Validation Set: Average Loss: 2.3397, Accuracy: 34.22%
Best Accuracy: 34.22%
Change lr:0.010000000000000002


Training Epoch #15: 100%|██████████| 990/990 [03:18<00:00,  4.98it/s, accuracy=32.1, loss=2.3] 


Train Epoch: 15, Loss: 2.457023, Accuracy: 32.14%


Validation Epoch #15: 100%|██████████| 8/8 [00:01<00:00,  7.99it/s, accuracy=43.6, loss=1.94] 


Validation Set: Average Loss: 1.9415, Accuracy: 43.56%
Best Accuracy: 43.56%


Training Epoch #16: 100%|██████████| 990/990 [03:19<00:00,  4.96it/s, accuracy=34.4, loss=2.34]


Train Epoch: 16, Loss: 2.360078, Accuracy: 34.43%


Validation Epoch #16: 100%|██████████| 8/8 [00:00<00:00,  8.18it/s, accuracy=44.4, loss=1.88] 


Validation Set: Average Loss: 1.8753, Accuracy: 44.44%
Best Accuracy: 44.44%


Training Epoch #17: 100%|██████████| 990/990 [02:28<00:00,  6.68it/s, accuracy=35.4, loss=2.29]


Train Epoch: 17, Loss: 2.319121, Accuracy: 35.38%


Validation Epoch #17: 100%|██████████| 8/8 [00:00<00:00, 13.13it/s, accuracy=46.2, loss=1.83] 


Validation Set: Average Loss: 1.8276, Accuracy: 46.22%
Best Accuracy: 46.22%


Training Epoch #18: 100%|██████████| 990/990 [03:05<00:00,  5.33it/s, accuracy=36.2, loss=2.48]


Train Epoch: 18, Loss: 2.282511, Accuracy: 36.16%


Validation Epoch #18: 100%|██████████| 8/8 [00:00<00:00,  8.35it/s, accuracy=46.7, loss=1.82] 


Validation Set: Average Loss: 1.8194, Accuracy: 46.67%
Best Accuracy: 46.67%


Training Epoch #19: 100%|██████████| 990/990 [03:16<00:00,  5.04it/s, accuracy=36.8, loss=2.13]


Train Epoch: 19, Loss: 2.259003, Accuracy: 36.79%


Validation Epoch #19: 100%|██████████| 8/8 [00:00<00:00,  8.28it/s, accuracy=46.4, loss=1.81] 


Validation Set: Average Loss: 1.8087, Accuracy: 46.44%


Training Epoch #20: 100%|██████████| 990/990 [03:02<00:00,  5.43it/s, accuracy=37.5, loss=2.36]


Train Epoch: 20, Loss: 2.234407, Accuracy: 37.50%


Validation Epoch #20: 100%|██████████| 8/8 [00:00<00:00,  8.03it/s, accuracy=46.9, loss=1.75] 


Validation Set: Average Loss: 1.7513, Accuracy: 46.89%
Best Accuracy: 46.89%


Training Epoch #21: 100%|██████████| 990/990 [03:12<00:00,  5.14it/s, accuracy=38.1, loss=2.12]


Train Epoch: 21, Loss: 2.210109, Accuracy: 38.10%


Validation Epoch #21: 100%|██████████| 8/8 [00:00<00:00,  8.30it/s, accuracy=48, loss=1.72]   


Validation Set: Average Loss: 1.7163, Accuracy: 48.00%
Best Accuracy: 48.00%


Training Epoch #22: 100%|██████████| 990/990 [02:56<00:00,  5.62it/s, accuracy=38.6, loss=2.63]


Train Epoch: 22, Loss: 2.191827, Accuracy: 38.58%


Validation Epoch #22: 100%|██████████| 8/8 [00:00<00:00, 10.73it/s, accuracy=45.8, loss=1.74] 


Validation Set: Average Loss: 1.7388, Accuracy: 45.78%


Training Epoch #23: 100%|██████████| 990/990 [02:53<00:00,  5.69it/s, accuracy=39.1, loss=1.89]


Train Epoch: 23, Loss: 2.166717, Accuracy: 39.12%


Validation Epoch #23: 100%|██████████| 8/8 [00:01<00:00,  7.69it/s, accuracy=50.7, loss=1.73] 


Validation Set: Average Loss: 1.7311, Accuracy: 50.67%
Best Accuracy: 50.67%


Training Epoch #24: 100%|██████████| 990/990 [02:22<00:00,  6.93it/s, accuracy=39.8, loss=2.72]


Train Epoch: 24, Loss: 2.141939, Accuracy: 39.84%


Validation Epoch #24: 100%|██████████| 8/8 [00:01<00:00,  7.75it/s, accuracy=49.8, loss=1.67] 


Validation Set: Average Loss: 1.6696, Accuracy: 49.78%


Training Epoch #25: 100%|██████████| 990/990 [02:04<00:00,  7.92it/s, accuracy=40.2, loss=1.87]


Train Epoch: 25, Loss: 2.121580, Accuracy: 40.18%


Validation Epoch #25: 100%|██████████| 8/8 [00:00<00:00, 13.22it/s, accuracy=50, loss=1.67]   


Validation Set: Average Loss: 1.6669, Accuracy: 50.00%


Training Epoch #26: 100%|██████████| 990/990 [02:29<00:00,  6.62it/s, accuracy=40.5, loss=2.16]


Train Epoch: 26, Loss: 2.109546, Accuracy: 40.51%


Validation Epoch #26: 100%|██████████| 8/8 [00:01<00:00,  7.21it/s, accuracy=50, loss=1.67]   


Validation Set: Average Loss: 1.6677, Accuracy: 50.00%


Training Epoch #27: 100%|██████████| 990/990 [02:09<00:00,  7.64it/s, accuracy=41.3, loss=1.82]


Train Epoch: 27, Loss: 2.081334, Accuracy: 41.27%


Validation Epoch #27: 100%|██████████| 8/8 [00:00<00:00, 13.36it/s, accuracy=50.9, loss=1.65] 


Validation Set: Average Loss: 1.6457, Accuracy: 50.89%
Best Accuracy: 50.89%


Training Epoch #28: 100%|██████████| 990/990 [02:12<00:00,  7.48it/s, accuracy=41.5, loss=2.26]


Train Epoch: 28, Loss: 2.076736, Accuracy: 41.46%


Validation Epoch #28: 100%|██████████| 8/8 [00:00<00:00, 12.32it/s, accuracy=52.7, loss=1.57] 


Validation Set: Average Loss: 1.5749, Accuracy: 52.67%
Best Accuracy: 52.67%


Training Epoch #29: 100%|██████████| 990/990 [02:08<00:00,  7.72it/s, accuracy=41.9, loss=1.81]


Train Epoch: 29, Loss: 2.068666, Accuracy: 41.95%


Validation Epoch #29: 100%|██████████| 8/8 [00:00<00:00, 13.56it/s, accuracy=50, loss=1.65]   


Validation Set: Average Loss: 1.6510, Accuracy: 50.00%


Training Epoch #30: 100%|██████████| 990/990 [02:01<00:00,  8.17it/s, accuracy=42.2, loss=2.28]


Train Epoch: 30, Loss: 2.041643, Accuracy: 42.19%


Validation Epoch #30: 100%|██████████| 8/8 [00:00<00:00, 13.27it/s, accuracy=51.3, loss=1.58] 

Validation Set: Average Loss: 1.5810, Accuracy: 51.33%
Final Best Accuracy: 52.67%





# 載入訓練好ResNet18

In [11]:
resnet18_model = models.resnet18(pretrained=False)
resnet18_model.fc = torch.nn.Linear(model.fc.in_features, 50)
resnet18_model.load_state_dict(torch.load("resnet18_best.pt"))
resnet18_model.to(device)
resnet18_model.eval()



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [12]:
test_augmentations = transforms.Compose([
    transforms.Resize(96),
    transforms.CenterCrop(96),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

test_dataset = ImageNetDataset(data_file='test.txt', root_dir='', transform=test_augmentations)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 測試ResNet18

In [16]:
def test(model):
    model.eval()
    correct = 0
    test_loader_len = len(test_loader.dataset)
    test_loader_iter = tqdm(test_loader, total=len(test_loader), desc="Testing")
    
    all_preds = []
    all_targets = []
    
    start_time = time.time()
    
    with torch.no_grad():
        for data, target in test_loader_iter:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()
            
            all_preds.extend(pred.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
            
            test_loader_iter.set_postfix(accuracy=100. * correct.item() / test_loader_len)
    
    end_time = time.time()
    elapsed_time = end_time - start_time
    
    accuracy = 100. * correct / test_loader_len
    
    # 計算 Precision, Recall 和 F1-score
    precision = 100. * precision_score(all_targets, all_preds, average='macro')
    recall = 100. * recall_score(all_targets, all_preds, average='macro')
    f1 = 100. * f1_score(all_targets, all_preds, average='macro')
    
    # 計算 FLOPS
    flops = torchprofile.profile_macs(model, torch.randn(1, *data.shape[1:]).to(device))
    
    return accuracy.item(), precision, recall, f1, flops, elapsed_time

resnet_acc, resnet_precision, resnet_recall, resnet_f1, resnet_flops, resnet_elapsed_time = test(resnet18_model)
print(f"Accuracy: {resnet_acc:.2f}%, Precision: {resnet_precision:.2f}%, Recall: {resnet_recall:.2f}%, F1 Score: {resnet_f1:.2f}%, FLOPS: {resnet_flops:d}, Elapsed Time: {resnet_elapsed_time:.2f} seconds")

Testing: 100%|██████████| 8/8 [00:00<00:00, 12.99it/s, accuracy=51.3]


Accuracy: 51.33%, Precision: 52.68%, Recall: 51.33%, F1 Score: 50.48%, FLOPS: 333585408, Elapsed Time: 0.62 seconds
