In [3]:
import cv2
import numpy as np
from matplotlib import pyplot as plt
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import time
import random
import torch.nn.functional as F

from torch.nn import init
from torch.nn.modules.utils import _pair
import math
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from thop import profile, clever_format
import gc

In [4]:
def set_seed(seed):
    """Seed every RNG this notebook relies on so runs are reproducible.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (CPU and, when available, all CUDA devices), and switches cuDNN to
    deterministic, non-autotuned kernels.

    Args:
        seed: Integer seed applied to every generator.
    """
    # `random` is imported by the notebook, so seed it alongside numpy/torch.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False



In [5]:
class AttentionLayer(nn.Module):
    """Predicts per-sample mixing weights over a bank of convolution kernels.

    The input is average-pooled to one vector per sample, scored by a
    two-layer MLP, and reshaped so that every (kernel, out_channel,
    in_channel) triple gets its own logit. A temperature-scaled softmax over
    the kernel axis (dim=1) makes the weights for each (out, in) pair sum to 1.

    Args:
        c_dim: Unused; kept only for signature compatibility.
        hidden_dim: Width of the MLP's hidden layer.
        nof_kernels: Number of kernels being mixed.
        out_channel: Output channels of the owning convolution.
        in_channel: Channels of the input ``x`` (the MLP's input width).
    """

    def __init__(self, c_dim, hidden_dim, nof_kernels, out_channel, in_channel):
        super().__init__()
        self.nof_kernels = nof_kernels
        self.out_channel = out_channel
        self.in_channel = in_channel
        # Pool every channel to a single value, then drop the 1x1 spatial dims.
        self.global_pooling = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())
        n_scores = nof_kernels * out_channel * in_channel
        self.to_scores = nn.Sequential(
            nn.Linear(in_channel, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, n_scores),
        )

    def forward(self, x, temperature=1):
        """Return softmax weights shaped (batch, nof_kernels, out_channel, in_channel)."""
        descriptor = self.global_pooling(x)
        target_shape = (x.shape[0], self.nof_kernels, self.out_channel, self.in_channel)
        logits = self.to_scores(descriptor).reshape(target_shape)
        return F.softmax(logits / temperature, dim=1)

In [6]:
class DynamicConv2d(nn.Module):
    """2-D convolution whose kernel is an input-dependent mixture of ``nof_kernels`` kernels.

    For every sample an ``AttentionLayer`` produces softmax weights over the
    kernel bank (separately for each out/in channel pair). The weighted sum is
    that sample's kernel. All samples are then convolved in one grouped
    ``F.conv2d`` call by folding the batch axis into the channel axis.

    Args:
        nof_kernels: Size of the kernel bank.
        reduce: Divisor for the attention MLP hidden width,
            ``max(8, in_channels // reduce)``.
        in_channels, out_channels, kernel_size, stride, padding, dilation:
            Same meaning as for ``nn.Conv2d``.
        bias: If True, one bias vector (shared by all kernels in the bank)
            is learned.
    """

    def __init__(self, nof_kernels, reduce, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.conv_args = {'stride': stride, 'padding': padding, 'dilation': dilation}
        self.nof_kernels = nof_kernels
        self.kernel_size = _pair(kernel_size)
        # kernels_weights: (nof_kernels, out_channels, in_channels, kH, kW);
        # allocated uninitialised and filled by initialize_parameters().
        self.kernels_weights = nn.Parameter(torch.Tensor(
            nof_kernels, out_channels, in_channels, *self.kernel_size), requires_grad=True)
        if bias:
            self.kernels_bias = nn.Parameter(torch.Tensor(out_channels), requires_grad=True)
        else:
            self.register_parameter('kernels_bias', None)

        self.attention = AttentionLayer(3, max(8, in_channels // reduce), nof_kernels, out_channels, in_channels)
        self.initialize_parameters()

    def initialize_parameters(self):
        """Initialise each bank kernel and the bias with the nn.Conv2d default scheme."""
        for i_kernel in range(self.nof_kernels):
            init.kaiming_uniform_(self.kernels_weights[i_kernel], a=math.sqrt(5))
        if self.kernels_bias is not None:
            # fan_in = in_channels * kH * kW
            bound = 1 / math.sqrt(self.kernels_weights[0, 0].numel())
            nn.init.uniform_(self.kernels_bias, -bound, bound)

    def forward(self, x, temperature=1):
        """Convolve ``x`` with per-sample mixed kernels.

        Args:
            x: Input of shape (B, in_channels, H, W). It does not need to be contiguous.
            temperature: Softmax temperature forwarded to the attention layer.

        Returns:
            Tensor of shape (B, out_channels, H', W').
        """
        batch_size = x.shape[0]
        # alphas: (B, nof_kernels, out_channels, in_channels)
        alphas = self.attention(x, temperature)

        # (1, K, O, I, kH, kW) * (B, K, O, I, 1, 1), summed over K -> (B, O, I, kH, kW)
        agg_weights = torch.sum(
            self.kernels_weights.unsqueeze(0)
            * alphas.view(batch_size, self.nof_kernels, self.out_channels, self.in_channels, 1, 1),
            dim=1)
        # Stack the per-sample kernels for a grouped conv: (B * O, I, kH, kW)
        agg_weights = agg_weights.reshape(-1, *agg_weights.shape[-3:])

        # The shared bias is repeated once per sample (group): (B * O,)
        agg_bias = self.kernels_bias.repeat(batch_size) if self.kernels_bias is not None else None

        # Fold the batch into channels: (1, B * in_channels, H, W). `reshape` rather
        # than `view` so non-contiguous inputs (e.g. channels_last, permuted) work.
        x_grouped = x.reshape(1, -1, *x.shape[-2:])
        # groups=B: group g sees sample g's in_channels and uses sample g's kernels.
        out = F.conv2d(x_grouped, agg_weights, agg_bias, groups=batch_size,
                       **self.conv_args)
        # (1, B * out_channels, H', W') -> (B, out_channels, H', W')
        return out.view(batch_size, -1, *out.shape[-2:])

In [7]:
class DynamicCNN(nn.Module):
    """Four-stage CNN built from ``DynamicConv2d`` layers followed by an MLP head.

    Each stage is DynamicConv2d(3x3, stride 1, padding 1) -> BatchNorm2d ->
    ReLU -> MaxPool2d(2). Channel widths go 3 -> 32 -> 64 -> 128 -> 256. The
    3x3/padding-1 convs keep the spatial size, and each pool floor-halves it,
    so the final feature map is ``input_size // 16`` on a side. The classifier
    input width is derived from that rather than hard-coded for 128x128.

    Args:
        num_classes: Number of output logits.
        nof_kernels: Kernel-bank size of every DynamicConv2d.
        input_size: Side length of the square input images. The default (128)
            reproduces the original ``256 * 8 * 8`` classifier input.

    Raises:
        ValueError: If ``input_size`` is below 16 (no spatial extent would remain).
    """

    def __init__(self, num_classes=50, nof_kernels=4, input_size=128):
        super(DynamicCNN, self).__init__()
        feat_side = input_size // 16
        if feat_side < 1:
            raise ValueError(f"input_size must be >= 16, got {input_size}")

        widths = [3, 32, 64, 128, 256]
        # Built in stage order, so parameter creation (and RNG use) matches a literal list.
        self.dycnn = nn.ModuleList([
            DynamicConv2d(nof_kernels=nof_kernels, reduce=4, in_channels=c_in, out_channels=c_out,
                          kernel_size=3, stride=1, padding=1, dilation=1, bias=True)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ])
        self.norm = nn.ModuleList([nn.BatchNorm2d(c_out) for c_out in widths[1:]])
        self.act = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2, 0)

        self.fc = nn.Sequential(
            nn.Linear(widths[-1] * feat_side * feat_side, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes)
        )

    def forward(self, x, temperature=1):
        """Return class logits; ``temperature`` is passed to every DynamicConv2d."""
        out = x
        for conv, norm in zip(self.dycnn, self.norm):
            out = self.pool(self.act(norm(conv(out, temperature=temperature))))
        return self.fc(torch.flatten(out, 1))

In [8]:
class SimpleCNN(nn.Module):
    """Static-convolution baseline with the same stage layout as ``DynamicCNN``.

    Four Conv2d(3x3, padding 1) -> BatchNorm2d -> ReLU -> MaxPool2d(2)
    stages (3 -> 32 -> 64 -> 128 -> 256 channels), then an MLP head. The
    head's input width is derived from ``input_size`` (final map side is
    ``input_size // 16``) instead of being hard-coded for 128x128 inputs.

    Args:
        num_classes: Number of output logits.
        input_size: Side length of the square input images (default 128
            reproduces the original ``256 * 8 * 8`` classifier input).

    Raises:
        ValueError: If ``input_size`` is below 16.
    """

    def __init__(self, num_classes=50, input_size=128):
        super(SimpleCNN, self).__init__()
        feat_side = input_size // 16
        if feat_side < 1:
            raise ValueError(f"input_size must be >= 16, got {input_size}")

        # Shape comments are (C, H, W) for the default 128x128 input.
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, 3, 1, 1),    # [32, 128, 128]
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),        # [32, 64, 64]

            nn.Conv2d(32, 64, 3, 1, 1),   # [64, 64, 64]
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),        # [64, 32, 32]

            nn.Conv2d(64, 128, 3, 1, 1),  # [128, 32, 32]
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),        # [128, 16, 16]

            nn.Conv2d(128, 256, 3, 1, 1), # [256, 16, 16]
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),        # [256, 8, 8]
        )
        self.fc = nn.Sequential(
            nn.Linear(256 * feat_side * feat_side, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes)
        )

    def forward(self, x, temperature=1):
        """Return class logits.

        ``temperature`` is accepted (and ignored) only so the shared training
        loop can call this model and ``DynamicCNN`` the same way.
        """
        out = self.cnn(x)
        return self.fc(torch.flatten(out, 1))

In [9]:
from PIL import Image

class ImgDataset(Dataset):
    """Map-style dataset over an in-memory collection of samples with optional labels.

    Args:
        x: Indexable collection of samples (in this notebook, the uint8
            image batch produced by ``resize_input``).
        y: Optional labels, converted once to an int64 tensor.
        transform: Optional callable applied to each sample when it is fetched.
    """

    def __init__(self, x, y=None, transform=None):
        self.x = x
        self.y = None if y is None else torch.LongTensor(y)
        self.transform = transform

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        """Return the (transformed) sample, paired with its label when labels exist."""
        sample = self.x[index]
        if self.transform is not None:
            sample = self.transform(sample)
        if self.y is None:
            return sample
        return sample, self.y[index]


In [10]:
def random_channel(images, channel_combinations):
    """Keep only a subset of BGR channels of each image, cycling through combinations.

    Despite the name, the assignment is deterministic round-robin: image ``i``
    gets ``channel_combinations[i % len(channel_combinations)]``. Images are
    indexed in OpenCV BGR order (channel 0 = B, 1 = G, 2 = R). Each result is
    a basic slice of the input image, so it is a view rather than a copy.

    Args:
        images: Sequence of H x W x C arrays.
        channel_combinations: Non-empty sequence of names drawn from
            'BGR', 'GR', 'BG', 'R', 'G', 'B'.

    Returns:
        (channels, new_images): the combination name used for each image and
        the correspondingly sliced images, both in input order.

    Raises:
        ValueError: If ``channel_combinations`` is empty or contains an unknown
            name. Failing fast avoids silently reusing another image's slice.
    """
    channel_slices = {
        'BGR': slice(None),      # all channels
        'GR': slice(1, None),
        'BG': slice(None, 2),
        'R': slice(2, 3),
        'G': slice(1, 2),
        'B': slice(0, 1),
    }
    combos = list(channel_combinations)
    if not combos:
        raise ValueError("channel_combinations must not be empty")
    unknown = [name for name in combos if name not in channel_slices]
    if unknown:
        raise ValueError(f"unknown channel combination(s): {unknown}")

    channels, new_images = [], []
    for i, image in enumerate(images):
        name = combos[i % len(combos)]
        new_images.append(image[:, :, channel_slices[name]])
        channels.append(name)
    return channels, new_images

In [11]:
def create_model(model_name, nof_kernels=4, num_classes=50):
    """Instantiate one of the notebook's classifiers by name.

    Args:
        model_name: "DynamicCNN", or "base_model" for the static ``SimpleCNN``.
        nof_kernels: Kernel-bank size; only used by "DynamicCNN".
        num_classes: Number of output classes (default 50, the value the
            models were previously always built with).

    Returns:
        The newly constructed model.

    Raises:
        ValueError: For an unrecognised ``model_name``. Raising here replaces
            returning ``None``, which only failed later on ``.to(device)``.
    """
    if model_name == "DynamicCNN":
        return DynamicCNN(num_classes=num_classes, nof_kernels=nof_kernels)
    if model_name == "base_model":
        return SimpleCNN(num_classes=num_classes)
    raise ValueError(f"unknown model_name {model_name!r}; expected 'DynamicCNN' or 'base_model'")

In [12]:
def load_img(f):
    """Read an image-list file and load every listed image with OpenCV.

    Each non-blank line must be ``<image path> <integer label>``. The label is
    taken after the last whitespace, so paths containing spaces are accepted.

    Args:
        f: Path to the list file (e.g. 'train.txt').

    Returns:
        (imgs, lab): list of arrays returned by ``cv2.imread`` (default flag)
        and a ``np.uint8`` array of labels, in file order.

    Raises:
        FileNotFoundError: If an image cannot be read (``cv2.imread`` returned None).
        ValueError: On a malformed line, or a label outside 0..255 (it would
            otherwise wrap silently when cast to uint8).
    """
    imgs, lab = [], []
    # `with` guarantees the list file is closed.
    with open(f) as list_file:
        for line_no, line in enumerate(list_file, start=1):
            stripped = line.strip()
            if not stripped:
                continue  # tolerate blank and trailing lines
            parts = stripped.rsplit(maxsplit=1)
            if len(parts) != 2:
                raise ValueError(f"{f}:{line_no}: expected '<path> <label>', got {line!r}")
            fn, label = parts
            im1 = cv2.imread(fn)
            if im1 is None:
                raise FileNotFoundError(f"{f}:{line_no}: could not read image {fn!r}")
            imgs.append(im1)
            lab.append(int(label))

    out_of_range = [v for v in lab if not 0 <= v <= 255]
    if out_of_range:
        raise ValueError(f"{f}: labels outside uint8 range: {out_of_range[:5]}")
    return imgs, np.asarray(lab, np.uint8)


In [13]:
def resize_input(channels, images, img_size):
    """Resize images to img_size x img_size and zero-pad them back to three BGR channels.

    ``images[i]`` holds only the channels named by ``channels[i]``. Those
    channels are written into their BGR slots of a zero-filled canvas, so
    missing channels stay 0.

    Args:
        channels: Per-image combination names ('BGR', 'BG', 'GR', 'B', 'G', 'R').
        images: Per-image arrays whose last axis holds exactly those channels.
        img_size: Output side length.

    Returns:
        (c, x): ``c`` is an (N, 3) uint8 presence mask in B, G, R order;
        ``x`` is an (N, img_size, img_size, 3) uint8 image batch.
    """
    # BGR slots occupied by each combination.
    slots = {"B": slice(0, 1), "G": slice(1, 2), "R": slice(2, 3),
             "BG": slice(0, 2), "GR": slice(1, 3), "BGR": slice(0, 3)}
    # Which of B, G, R are present for each combination.
    masks = {"BGR": (1, 1, 1), "BG": (1, 1, 0), "GR": (0, 1, 1),
             "B": (1, 0, 0), "G": (0, 1, 0), "R": (0, 0, 1)}

    total = len(images)
    x = np.zeros((total, img_size, img_size, 3), dtype=np.uint8)
    c = np.zeros((total, 3), dtype=np.uint8)
    for idx, img in enumerate(images):
        name = channels[idx]
        canvas = np.zeros((img_size, img_size, 3), dtype=img.dtype)
        resized = cv2.resize(img, (img_size, img_size))

        slot = slots.get(name)
        if slot is None:
            print("error")
        else:
            # cv2.resize drops a trailing single-channel axis; restore it.
            if resized.ndim == 2:
                resized = resized[:, :, np.newaxis]
            canvas[:, :, slot] = resized
        x[idx] = canvas
        c[idx] = masks[name]
    return c, x

In [14]:
def calc_complexity(model, inputs):
    """Print and return the per-forward-pass op count and the parameter count of ``model``.

    ``thop.profile`` reports multiply-accumulate operations (MACs), not FLOPs
    (FLOPs are roughly 2 x MACs), so the output is labelled MACs. thop only
    counts operations of module types it has rules for. The functional
    ``F.conv2d`` inside ``DynamicConv2d`` is presumably not counted, so verify
    before comparing DynamicCNN against the baseline.

    Parameters are counted directly from ``model.parameters()``. This includes
    parameters owned by custom modules (e.g. ``DynamicConv2d.kernels_weights``)
    regardless of thop's per-module rules.

    Args:
        model: Module to profile.
        inputs: One example batch, passed as the sole forward argument.

    Returns:
        (macs, n_params) as raw numbers.
    """
    macs, _ = profile(model, inputs=(inputs,))
    n_params = sum(p.numel() for p in model.parameters())
    macs_str, params_str = clever_format([macs, n_params], "%.3f")
    print(f"MACs: {macs_str}")
    print(f"Params: {params_str}")
    return macs, n_params

In [15]:
def _accumulate_batches(model, loader, criterion, temperature, device, optimizer=None):
    """Run one pass over ``loader``; return (#correct predictions, summed per-sample loss).

    When ``optimizer`` is given, each batch is also back-propagated and the
    optimizer stepped. Otherwise the caller should wrap the call in ``torch.no_grad()``.
    """
    correct, loss_sum = 0, 0.0
    for batch in loader:
        inputs, labels = batch[0], batch[1]
        if optimizer is not None:
            optimizer.zero_grad()
        logits = model(inputs.to(device), temperature=temperature)
        batch_loss = criterion(logits, labels.to(device))
        if optimizer is not None:
            batch_loss.backward()
            optimizer.step()
        correct += (logits.detach().argmax(dim=1).cpu() == labels.cpu()).sum().item()
        # criterion returns the batch *mean*; weight by batch size so that
        # dividing by the dataset size later yields a true per-sample mean.
        loss_sum += batch_loss.item() * labels.size(0)
    return correct, loss_sum


def train(model, train_loader, val_loader, eval_time, num_epoch, n_train, n_val, temperature, lr, device):
    """Train ``model`` with Adam and cross-entropy, reporting metrics every ``eval_time`` epochs.

    Args:
        model: Module already on ``device``; its forward must accept ``temperature=``.
        train_loader, val_loader: Loaders yielding ``(inputs, labels)`` batches.
        eval_time: Validate and print when ``epoch % eval_time == 0``.
        num_epoch: Number of training epochs.
        n_train, n_val: Dataset sizes used to average accuracy and loss.
        temperature: Forwarded to ``model``.
        lr: Adam learning rate.
        device: Device that inputs and labels are moved to.

    Reported losses are per-sample means over the whole dataset. Summing batch
    means and dividing by the sample count would under-report them by roughly
    the batch size.
    """
    criterion = nn.CrossEntropyLoss()  # classification task
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epoch):
        print(epoch)
        epoch_start_time = time.time()

        model.train()  # enable training-mode behaviour (BatchNorm statistics, dropout)
        train_correct, train_loss_sum = _accumulate_batches(
            model, train_loader, criterion, temperature, device, optimizer=optimizer)

        if epoch % eval_time == 0:
            model.eval()
            with torch.no_grad():
                val_correct, val_loss_sum = _accumulate_batches(
                    model, val_loader, criterion, temperature, device)

            train_acc, train_loss = train_correct / n_train, train_loss_sum / n_train
            val_acc, val_loss = val_correct / n_val, val_loss_sum / n_val
            print('[%03d/%03d] %2.2f sec(s) Train Acc: %3.6f Loss: %3.6f | Val Acc: %3.6f loss: %3.6f' %
                  (epoch + 1, num_epoch, time.time() - epoch_start_time,
                   train_acc, train_loss, val_acc, val_loss))

            print("Train/epoch", epoch)
            print("Train/acc", train_acc)
            print("Train/loss", train_loss)
            print("Val/epoch", epoch)
            print("Val/acc", val_acc)
            print("Val/loss", val_loss)


In [16]:
def test(model, test_loader, n_test, temperature, device):
    """Evaluate ``model`` on ``test_loader``, print summary metrics, and return them.

    Reports accuracy, support-weighted precision/recall/F1 (scikit-learn,
    ``average='weighted'``) and the mean per-sample cross-entropy loss.

    Args:
        model: Module already on ``device``; its forward must accept ``temperature=``.
        test_loader: Loader yielding ``(inputs, labels)`` batches.
        n_test: Number of test samples, used to average the loss.
        temperature: Forwarded to ``model``.
        device: Device that inputs and labels are moved to.

    Returns:
        dict with keys 'accuracy', 'precision', 'recall', 'f1', 'loss'.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()  # classification task

    all_labels, all_preds = [], []
    loss_sum = 0.0
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data[0], data[1]
            logits = model(inputs.to(device), temperature=temperature)
            # criterion returns the batch mean; weight by batch size so the
            # division by n_test below gives a true per-sample mean.
            loss_sum += criterion(logits, labels.to(device)).item() * labels.size(0)

            _, preds = torch.max(logits, 1)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    metrics = {
        'accuracy': accuracy_score(all_labels, all_preds),
        'precision': precision_score(all_labels, all_preds, average='weighted'),
        'recall': recall_score(all_labels, all_preds, average='weighted'),
        'f1': f1_score(all_labels, all_preds, average='weighted'),
        # True division; floor division (`//`) would truncate the loss to an integer.
        'loss': loss_sum / n_test,
    }

    print("--------Test result-------------")
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"F1-score: {metrics['f1']:.4f}")
    print(f"Loss: {metrics['loss']:.4f}")
    print("--------------------------------")
    return metrics

In [None]:
# Experiment 5.1: sweep the DynamicConv2d kernel-bank size (nof_kernels).
def experiment_5_1(kernel_counts=(1, 3, 5, 7)):
    """Train and test one DynamicCNN per kernel-bank size on channel-masked images.

    Each image keeps one of six BGR channel combinations, assigned round-robin
    by ``random_channel``. ``resize_input`` then zero-pads it back to three
    channels. Augmentation is applied to the training split only.

    Args:
        kernel_counts: ``nof_kernels`` values to sweep; the default reproduces
            the original 1/3/5/7 sweep.
    """
    # Hyper-parameters
    #############
    eval_time = 1
    num_epoch = 50
    img_size = 144      # side length images are resized to before cropping
    input_size = 128    # crop side length actually fed to the network
    batch_size = 128
    lr = 0.001
    model_name = 'DynamicCNN'
    temperature = 1
    seed = 42
    # The number of classes is create_model's default (50).
    #############
    set_seed(seed)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    x, y = load_img('train.txt')
    vx, vy = load_img('val.txt')
    tx, ty = load_img('test.txt')
    print("--1--")
    # Keep only a subset of channels per image, cycling through the combinations.
    channel_combinations = ['BGR', 'GR', 'BG', 'R', 'G', 'B']
    train_channels, x_new = random_channel(x, channel_combinations)
    val_channels, vx_new = random_channel(vx, channel_combinations)
    test_channels, tx_new = random_channel(tx, channel_combinations)
    del x, vx, tx
    print("--2--")
    # Resize and zero-pad back to three channels (the returned channel masks are unused here).
    _, x_resize = resize_input(train_channels, x_new, img_size=img_size)
    _, vx_resize = resize_input(val_channels, vx_new, img_size=img_size)
    _, tx_resize = resize_input(test_channels, tx_new, img_size=img_size)
    del x_new, vx_new, tx_new
    print("--3--")
    # Training: data augmentation.
    train_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((img_size, img_size)),                   # resize
        transforms.RandomRotation(degrees=30),                     # rotate
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # translate
        transforms.RandomCrop(input_size),                         # random crop
        transforms.RandomHorizontalFlip(),                         # horizontal flip
        transforms.ToTensor(),                                     # to tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # normalize
    ])
    # Evaluation: deterministic resize + centre crop, no augmentation.
    test_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((img_size, img_size)),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # Datasets and loaders.
    train_set = ImgDataset(x_resize, y, train_transform)
    val_set = ImgDataset(vx_resize, vy, test_transform)
    test_set = ImgDataset(tx_resize, ty, test_transform)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, worker_init_fn=lambda _: np.random.seed(42), num_workers=4)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4)
    del x_resize, vx_resize, tx_resize
    n_train, n_val, n_test = len(train_set), len(val_set), len(test_set)

    print("--4--")
    for nof_kernels in kernel_counts:
        print("-----number of kernels = "+str(nof_kernels)+"-----")
        # Re-seed per configuration so every model starts from the same RNG state
        # (weight init and shuffle order). Otherwise results would also depend on
        # how much randomness the earlier runs consumed.
        set_seed(seed)
        model = create_model(model_name, nof_kernels=nof_kernels).to(device)
        # Complexity of one forward pass on a single dummy image.
        dummy = torch.randn(1, 3, input_size, input_size)
        calc_complexity(model, dummy.to(device))

        train(model, train_loader, val_loader, eval_time, num_epoch, n_train, n_val, temperature, lr, device)
        test(model, test_loader, n_test, temperature, device)

        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # release cached GPU memory before the next model
        
  
# Launch the nof_kernels sweep (long-running: trains and tests one DynamicCNN per setting).
experiment_5_1() 

In [None]:
# Experiment 3.1.2: DynamicCNN vs. the static SimpleCNN baseline on mixed channel combinations.
def experiment_3_1_2(nof_kernels):
    """For each list of channel combinations, train and test both models.

    Within one experiment, images cycle round-robin through the listed
    combinations (``random_channel``) and are zero-padded back to three
    channels (``resize_input``).

    Args:
        nof_kernels: Kernel-bank size for the DynamicCNN runs (ignored by the baseline).
    """
    # Hyper-parameters
    #############
    eval_time = 1
    num_epoch = 40
    img_size = 144      # side length images are resized to before cropping
    input_size = 128    # crop side length actually fed to the network
    batch_size = 128
    lr = 0.001
    temperature = 1
    seed = 42
    # The number of classes is create_model's default (50).
    #############
    set_seed(seed)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    x, y = load_img('train.txt')
    vx, vy = load_img('val.txt')
    tx, ty = load_img('test.txt')

    # Transforms do not depend on the channel combination, so build them once.
    # Training: data augmentation.
    train_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((img_size, img_size)),                   # resize
        transforms.RandomRotation(degrees=30),                     # rotate
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # translate
        transforms.RandomCrop(input_size),                         # random crop
        transforms.RandomHorizontalFlip(),                         # horizontal flip
        transforms.ToTensor(),                                     # to tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # normalize
    ])
    # Evaluation: deterministic resize + centre crop, no augmentation.
    test_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((img_size, img_size)),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Each entry is one experiment; its images cycle through the listed combinations.
    channel_combinations = [['BGR','GR','BG','R','G','B'], ['R','G','B'],['BG','GR'],['R', 'BG','BGR'], ['R', 'BG'], ['B', 'BG']]

    for combination in channel_combinations:
        print("----channel_combinations = ", combination, end="----\n")
        train_channels, x_new = random_channel(x, combination)
        val_channels, vx_new = random_channel(vx, combination)
        test_channels, tx_new = random_channel(tx, combination)

        # Resize and zero-pad back to three channels (channel masks unused here).
        _, x_resize = resize_input(train_channels, x_new, img_size=img_size)
        _, vx_resize = resize_input(val_channels, vx_new, img_size=img_size)
        _, tx_resize = resize_input(test_channels, tx_new, img_size=img_size)
        del x_new, vx_new, tx_new

        # Datasets and loaders.
        train_set = ImgDataset(x_resize, y, train_transform)
        val_set = ImgDataset(vx_resize, vy, test_transform)
        test_set = ImgDataset(tx_resize, ty, test_transform)
        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, worker_init_fn=lambda _: np.random.seed(42), num_workers=4)
        val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=4)
        test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4)
        del x_resize, vx_resize, tx_resize
        n_train, n_val, n_test = len(train_set), len(val_set), len(test_set)

        for model_name in ['DynamicCNN', 'base_model']:
            # Same seed for every run so both models get identical RNG state
            # for initialisation and shuffling; differences then reflect the model.
            set_seed(seed)
            model = create_model(model_name, nof_kernels=nof_kernels).to(device)
            # Complexity of one forward pass on a single dummy image.
            dummy = torch.randn(1, 3, input_size, input_size)
            calc_complexity(model, dummy.to(device))

            train(model, train_loader, val_loader, eval_time, num_epoch, n_train, n_val, temperature, lr, device)
            test(model, test_loader, n_test, temperature, device)

            del model
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # release cached GPU memory before the next model
        
    
        
        
# Launch the mixed channel-combination comparison (DynamicCNN vs. base_model, 7 kernels per bank).
experiment_3_1_2(nof_kernels=7)   


In [None]:
# Experiment 3.1.1: DynamicCNN vs. the static SimpleCNN baseline, one channel combination at a time.
def experiment_3_1_1(nof_kernels):
    """For each single channel combination, train and test both models.

    Every image in an experiment is reduced to the same combination
    (``random_channel`` with a one-element list) and zero-padded back to three
    channels (``resize_input``).

    Args:
        nof_kernels: Kernel-bank size for the DynamicCNN runs (ignored by the baseline).
    """
    # Hyper-parameters
    #############
    eval_time = 1
    num_epoch = 40
    img_size = 144      # side length images are resized to before cropping
    input_size = 128    # crop side length actually fed to the network
    batch_size = 128
    lr = 0.001
    temperature = 1
    seed = 42
    # The number of classes is create_model's default (50).
    #############
    set_seed(seed)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    x, y = load_img('train.txt')
    vx, vy = load_img('val.txt')
    tx, ty = load_img('test.txt')

    # Transforms do not depend on the channel combination, so build them once.
    # Training: data augmentation.
    train_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((img_size, img_size)),                   # resize
        transforms.RandomRotation(degrees=30),                     # rotate
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # translate
        transforms.RandomCrop(input_size),                         # random crop
        transforms.RandomHorizontalFlip(),                         # horizontal flip
        transforms.ToTensor(),                                     # to tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # normalize
    ])
    # Evaluation: deterministic resize + centre crop, no augmentation.
    test_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((img_size, img_size)),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # One experiment per single combination.
    channel_combinations = [['B'], ['G'], ['R'], ['BG'], ['GR'], ['BGR']]

    for combination in channel_combinations:
        print("----channel_combinations = ", combination, end="----\n")
        train_channels, x_new = random_channel(x, combination)
        val_channels, vx_new = random_channel(vx, combination)
        test_channels, tx_new = random_channel(tx, combination)

        # Resize and zero-pad back to three channels (channel masks unused here).
        _, x_resize = resize_input(train_channels, x_new, img_size=img_size)
        _, vx_resize = resize_input(val_channels, vx_new, img_size=img_size)
        _, tx_resize = resize_input(test_channels, tx_new, img_size=img_size)
        del x_new, vx_new, tx_new

        # Datasets and loaders.
        train_set = ImgDataset(x_resize, y, train_transform)
        val_set = ImgDataset(vx_resize, vy, test_transform)
        test_set = ImgDataset(tx_resize, ty, test_transform)
        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, worker_init_fn=lambda _: np.random.seed(42), num_workers=4)
        val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=4)
        test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4)
        del x_resize, vx_resize, tx_resize
        n_train, n_val, n_test = len(train_set), len(val_set), len(test_set)

        for model_name in ['DynamicCNN', 'base_model']:
            # Same seed for every run so both models get identical RNG state
            # for initialisation and shuffling; differences then reflect the model.
            set_seed(seed)
            model = create_model(model_name, nof_kernels=nof_kernels).to(device)
            # Complexity of one forward pass on a single dummy image.
            dummy = torch.randn(1, 3, input_size, input_size)
            calc_complexity(model, dummy.to(device))

            train(model, train_loader, val_loader, eval_time, num_epoch, n_train, n_val, temperature, lr, device)
            test(model, test_loader, n_test, temperature, device)

            del model
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # release cached GPU memory before the next model
        
    
        
        
# Launch the single channel-combination comparison (DynamicCNN vs. base_model, 7 kernels per bank).
experiment_3_1_1(nof_kernels=7)   
