In [1]:
import os
import zipfile
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np

# Unpack the dataset archive (images.zip).
def unzip_dataset(zip_path, extract_path):
    """Extract every member of the zip archive into a directory.

    Args:
        zip_path: Path of the zip archive to read.
        extract_path: Directory the archive members are extracted into.

    Prints a confirmation line once extraction has finished. Returns None.
    """
    archive = zipfile.ZipFile(zip_path, mode='r')
    try:
        archive.extractall(path=extract_path)
    finally:
        # Always release the archive handle, even when extraction fails.
        archive.close()
    print(f"Extracted {zip_path} to {extract_path}")

# Assume images.zip sits in the current working directory.
zip_path = "images.zip"
extract_path = "dataset"
# Extract only when the target directory does not exist yet, so re-running
# this cell does not unpack the archive again.
if not os.path.exists(extract_path):
    unzip_dataset(zip_path, extract_path)

# Custom dataset class.
class MiniImageNetDataset(Dataset):
    """Image-classification dataset driven by a text index file.

    Each non-blank line of ``txt_file`` holds an image path followed by an
    integer class label, separated by whitespace. Images are opened relative
    to ``root_dir``, converted to RGB and passed through ``transform``.

    Args:
        txt_file: Path of the index file to read.
        root_dir: Directory the image paths in ``txt_file`` are joined onto.
        channels: Channel mode tag (e.g. "RGB", "RG", "R"). It is only stored
            on the instance; this class does not use it when loading images.
        transform: Optional callable applied to the PIL image in
            ``__getitem__``.

    Raises:
        ValueError: If a non-blank line has no label field or the label is
            not an integer.
    """

    def __init__(self, txt_file, root_dir, channels, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.channels = channels  # informational only; not read by this class
        self.data = []
        self.labels = []

        # Read the index file.
        with open(txt_file, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. an extra newline at EOF).
                    continue
                # Split on the LAST whitespace run so that paths containing
                # spaces still parse; the label is always the final field.
                img_path, label = line.rsplit(None, 1)
                self.data.append(img_path)
                self.labels.append(int(label))

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``."""
        img_path = os.path.join(self.root_dir, self.data[idx])
        # The context manager closes the underlying file; convert() returns a
        # separate image object that stays usable afterwards.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx]

def get_transform(channels):
    """Build the preprocessing pipeline for one channel configuration.

    Every pipeline resizes to 84x84 and converts to a tensor. For "R", "G",
    "B" and "RG" a channel slice is then taken and normalised with the
    statistics listed below; any other value gets the full three-channel
    normalisation.

    Args:
        channels: Channel mode string ("R", "G", "B", "RG"; anything else is
            treated as RGB).

    Returns:
        A ``transforms.Compose`` instance.
    """
    # (mode, channel slice to keep, normalisation mean, normalisation std)
    subset_specs = (
        ("R", slice(0, 1), [0.5], [0.5]),
        ("G", slice(1, 2), [0.5], [0.5]),
        ("B", slice(2, 3), [0.5], [0.5]),
        ("RG", slice(None, 2), [0.485, 0.456], [0.229, 0.224]),
    )

    steps = [transforms.Resize((84, 84)), transforms.ToTensor()]

    for mode, keep, mean, std in subset_specs:
        if channels == mode:
            # Bind the slice as a default so the lambda owns its own copy.
            steps.append(transforms.Lambda(lambda x, keep=keep: x[keep, :, :]))
            steps.append(transforms.Normalize(mean=mean, std=std))
            return transforms.Compose(steps)

    # Fallback: RGB
    steps.append(
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    )
    return transforms.Compose(steps)


## SimpleCNN

In [2]:
# Simple CNN
import torch.nn.functional as F
class SimpleCNN(nn.Module):
    """Two conv/ReLU/max-pool stages followed by a two-layer classifier head.

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Length of the output logit vector.
        input_size: Side length used to size ``fc1``; the flattened feature
            count is computed as ``32 * (input_size // 4) ** 2``.
    """

    def __init__(self, in_channels, num_classes=10, input_size=84):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)

        # Spatial size after pooling: each of the two pools halves the side.
        side = input_size // 4
        flat_features = 32 * side * side
        self.fc1 = nn.Linear(flat_features, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a batch of images to class logits."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        flat = x.view(x.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [None]:
# Helper: reduce a batch of images to the requested number of channels.
def convert_channels(images, target_channels):
    """Keep the leading ``target_channels`` channels of a batch.

    Args:
        images: Batch indexed as ``images[:, :k, :, :]`` (second axis is
            sliced).
        target_channels: 3 returns ``images`` unchanged; 2 or 1 keeps that
            many leading entries of the second axis.

    Raises:
        ValueError: For any other ``target_channels`` value.
    """
    if target_channels == 3:
        return images
    for keep in (2, 1):
        if target_channels == keep:
            return images[:, :keep, :, :]
    raise ValueError(f"Unsupported target channels: {target_channels}")

In [None]:
# train/test
def train_model(model, train_loader, val_loader, target_channels, num_epochs=10, device='cuda'):
    """Train ``model`` with Adam (lr=0.001) and cross-entropy, validating every epoch.

    Args:
        model: Network to train; it is moved to ``device``.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        target_channels: Passed to ``convert_channels`` for every batch.
        num_epochs: Number of passes over ``train_loader``.
        device: Device used for the model and the batches.

    Returns:
        The trained model (left in eval mode by the last validation pass).

    NOTE: a later cell of this notebook defines another ``train_model`` with
    a different signature; once that cell has run, this definition is
    replaced and must be re-executed before the SimpleCNN cells are re-run.
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(num_epochs):
        # ---- training phase ----
        model.train()
        loss_sum = 0.0
        correct = 0
        total = 0

        for images, labels in train_loader:
            # Reduce the batch to the requested number of channels.
            images = convert_channels(images, target_channels)
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # The criterion returns the batch MEAN; weight it by the batch
            # size so a short final batch is not over-represented in the
            # epoch average.
            loss_sum += loss.item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()

        train_loss = loss_sum / total
        train_acc = 100 * correct / total

        # ---- validation phase ----
        model.eval()
        loss_sum = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images = convert_channels(images, target_channels)
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                batch_size = labels.size(0)
                loss_sum += criterion(outputs, labels).item() * batch_size
                total += batch_size
                correct += (outputs.argmax(dim=1) == labels).sum().item()

        val_loss = loss_sum / total
        val_acc = 100 * correct / total

        print(f'Epoch [{epoch+1}/{num_epochs}] | '
              f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
    return model

# Test function.
def test_model(model, test_loader, target_channels, device='cuda'):
    """Evaluate ``model`` on ``test_loader`` and print loss and accuracy.

    Args:
        model: Network to evaluate; it is moved to ``device`` and put in
            eval mode.
        test_loader: Iterable of ``(images, labels)`` batches.
        target_channels: Passed to ``convert_channels`` for every batch.
        device: Device used for the model and the batches.

    Returns:
        Test accuracy as a percentage (0-100).

    NOTE: a later cell of this notebook defines another ``test_model`` with
    a different signature, which replaces this one once that cell has run.
    """
    model = model.to(device)
    model.eval()
    criterion = nn.CrossEntropyLoss()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images = convert_channels(images, target_channels)
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_size = labels.size(0)
            # Weight the per-batch mean loss by batch size so the reported
            # value is a true per-sample average.
            loss_sum += criterion(outputs, labels).item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()

    test_loss = loss_sum / total
    test_acc = 100 * correct / total
    print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')
    return test_acc

#### RGB

In [None]:
# Build the train / val / test datasets for the RGB configuration.
train_dataset_rgb, val_dataset_rgb, test_dataset_rgb = (
    MiniImageNetDataset(
        txt_file=os.path.join(extract_path, list_name),
        root_dir=extract_path,
        transform=get_transform("RGB"),
        channels="RGB",
    )
    for list_name in ("train.txt", "val.txt", "test.txt")
)

# One shuffled loader per split, 64 samples per batch.
train_rgb_loader, val_rgb_loader, test_rgb_loader = (
    DataLoader(split, batch_size=64, shuffle=True)
    for split in (train_dataset_rgb, val_dataset_rgb, test_dataset_rgb)
)

# Sanity check: print the tensor shape of the first training batch only.
for images, labels in train_rgb_loader:
    print("Image shape:", images.shape)
    break

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 50
num_epochs = 10
input_size = 84  # set according to the actual image size (get_transform resizes to 84x84)

# Train and test with 3 channels (RGB).
print("Training and Testing SimpleCNN with 3 channels (RGB)...")
model_rgb = SimpleCNN(in_channels=3, num_classes=num_classes, input_size=input_size)
model_rgb = train_model(model_rgb, train_rgb_loader, val_rgb_loader, target_channels=3, num_epochs=num_epochs, device=device)
# test_model returns the accuracy as a percentage.
test_acc_rgb = test_model(model_rgb, test_rgb_loader, target_channels=3, device=device)

print("\nSummary of Test Accuracies:")
print(f"SimpleCNN (3 channels - RGB): {test_acc_rgb:.2f}%")

In [None]:
from torchsummary import summary
# Model parameter count.
summary(model_rgb, (3, 84, 84))  # assumes the input is an RGB image

# ptflops can be used when a FLOPs estimate is wanted; it is optional, so a
# missing installation only prints a hint instead of failing the cell.
try:
    from ptflops import get_model_complexity_info
    flops, params = get_model_complexity_info(model_rgb, (3, 84, 84), as_strings=True, print_per_layer_stat=False)
    print(f"Computational complexity: {flops}")
    print(f"Number of parameters: {params}")
except ImportError:
    print("Please install ptflops to compute FLOPS: pip install ptflops")

In [None]:
# Persist the trained weights, then drop every RGB-specific object so the
# names can be reused and the memory released.
torch.save(model_rgb.state_dict(), 'model_rgb_state_dict.pth')
del (
    train_dataset_rgb,
    train_rgb_loader,
    val_dataset_rgb,
    val_rgb_loader,
    test_dataset_rgb,
    test_rgb_loader,
    model_rgb,
)

#### RG

In [None]:
# Build the train / val / test datasets for the RG configuration.
train_dataset_rg, val_dataset_rg, test_dataset_rg = (
    MiniImageNetDataset(
        txt_file=os.path.join(extract_path, list_name),
        root_dir=extract_path,
        transform=get_transform("RG"),
        channels="RG",
    )
    for list_name in ("train.txt", "val.txt", "test.txt")
)

# One shuffled loader per split, 64 samples per batch.
train_rg_loader, val_rg_loader, test_rg_loader = (
    DataLoader(split, batch_size=64, shuffle=True)
    for split in (train_dataset_rg, val_dataset_rg, test_dataset_rg)
)

# Sanity check: print the tensor shape of the first training batch only.
for images, labels in train_rg_loader:
    print("Image shape:", images.shape)
    break

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 50
num_epochs = 10
input_size = 84  # set according to the actual image size (get_transform resizes to 84x84)

# Train and test with 2 channels (RG).
# The banner and the summary line now name the configuration that is actually
# trained here (they previously omitted the count / said "3 channels - RGB").
print("Training and Testing SimpleCNN with 2 channels (RG)...")
model_rg = SimpleCNN(in_channels=2, num_classes=num_classes, input_size=input_size)
model_rg = train_model(model_rg, train_rg_loader, val_rg_loader, target_channels=2, num_epochs=num_epochs, device=device)
# test_model returns the accuracy as a percentage.
test_acc_rg = test_model(model_rg, test_rg_loader, target_channels=2, device=device)

print("\nSummary of Test Accuracies:")
print(f"SimpleCNN (2 channels - RG): {test_acc_rg:.2f}%")

In [None]:
from torchsummary import summary
# Model parameter count (2-channel input).
summary(model_rg, (2, 84, 84))

# ptflops can be used when a FLOPs estimate is wanted; it is optional, so a
# missing installation only prints a hint instead of failing the cell.
try:
    from ptflops import get_model_complexity_info
    flops, params = get_model_complexity_info(model_rg, (2, 84, 84), as_strings=True, print_per_layer_stat=False)
    print(f"Computational complexity: {flops}")
    print(f"Number of parameters: {params}")
except ImportError:
    print("Please install ptflops to compute FLOPS: pip install ptflops")

In [None]:
# Persist the trained weights, then drop every RG-specific object so the
# names can be reused and the memory released.
torch.save(model_rg.state_dict(), 'model_rg_state_dict.pth')
del (
    train_dataset_rg,
    train_rg_loader,
    val_dataset_rg,
    val_rg_loader,
    test_dataset_rg,
    test_rg_loader,
    model_rg,
)

#### R

In [None]:
# Build the train / val / test datasets for the R configuration.
train_dataset_r, val_dataset_r, test_dataset_r = (
    MiniImageNetDataset(
        txt_file=os.path.join(extract_path, list_name),
        root_dir=extract_path,
        transform=get_transform("R"),
        channels="R",
    )
    for list_name in ("train.txt", "val.txt", "test.txt")
)

# One shuffled loader per split, 64 samples per batch.
train_r_loader, val_r_loader, test_r_loader = (
    DataLoader(split, batch_size=64, shuffle=True)
    for split in (train_dataset_r, val_dataset_r, test_dataset_r)
)

# Sanity check: print the tensor shape of the first training batch only.
for images, labels in train_r_loader:
    print("Image shape:", images.shape)
    break

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 50
num_epochs = 10
input_size = 84  # set according to the actual image size (get_transform resizes to 84x84)

# Train and test with 1 channel (R).
# The banner and the summary line now name the configuration that is actually
# trained here (they previously omitted the count / said "3 channels - RGB").
print("Training and Testing SimpleCNN with 1 channel (R)...")
model_r = SimpleCNN(in_channels=1, num_classes=num_classes, input_size=input_size)
model_r = train_model(model_r, train_r_loader, val_r_loader, target_channels=1, num_epochs=num_epochs, device=device)
# test_model returns the accuracy as a percentage.
test_acc_r = test_model(model_r, test_r_loader, target_channels=1, device=device)

print("\nSummary of Test Accuracies:")
print(f"SimpleCNN (1 channel - R): {test_acc_r:.2f}%")

In [None]:
from torchsummary import summary
# Model parameter count (1-channel input).
summary(model_r, (1, 84, 84))

# ptflops can be used when a FLOPs estimate is wanted; it is optional, so a
# missing installation only prints a hint instead of failing the cell.
try:
    from ptflops import get_model_complexity_info
    flops, params = get_model_complexity_info(model_r, (1, 84, 84), as_strings=True, print_per_layer_stat=False)
    print(f"Computational complexity: {flops}")
    print(f"Number of parameters: {params}")
except ImportError:
    print("Please install ptflops to compute FLOPS: pip install ptflops")
# Persist the trained single-channel weights.
torch.save(model_r.state_dict(), 'model_r_state_dict.pth')

In [None]:
# Drop every R-specific object so the names can be reused and the memory
# released.
del (
    train_dataset_r,
    train_r_loader,
    val_dataset_r,
    val_r_loader,
    test_dataset_r,
    test_r_loader,
    model_r,
)

## DynamicConv

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DynamicConv(nn.Module):
    """Convolution whose kernel is generated from the input channel count.

    A small MLP (``kernel_generator``) receives a length-``max_in_channels``
    mask whose first ``in_c`` entries are 1 and produces a full
    ``(out_channels, max_in_channels, k, k)`` kernel. Only the first ``in_c``
    input-channel slices of that kernel are used, so one layer can process
    inputs with 1..``max_in_channels`` channels. The convolution has no bias
    and pads by ``kernel_size // 2``.

    Args:
        max_in_channels: Largest number of input channels supported.
        out_channels: Number of output feature maps.
        kernel_size: Side length of the square kernel.
        hidden_dim: Width of the generator's hidden layer.
    """

    def __init__(self, max_in_channels, out_channels, kernel_size=3, hidden_dim=64):
        super().__init__()
        self.max_in_channels = max_in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size

        # Kernel generator: channel mask in, flattened kernel out.
        self.kernel_generator = nn.Sequential(
            nn.Linear(max_in_channels, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_channels * max_in_channels * kernel_size * kernel_size)
        )

    def forward(self, x):
        """Convolve ``x`` with a kernel generated for its channel count.

        Raises:
            ValueError: If ``x`` is not 4-D or its channel count is outside
                ``1..max_in_channels``.
        """
        if x.dim() != 4:
            raise ValueError(
                f"DynamicConv expects a 4-D input, got {x.dim()} dimension(s)"
            )
        in_c = x.size(1)
        if not 1 <= in_c <= self.max_in_channels:
            # Fail early with a clear message instead of an opaque channel
            # mismatch raised inside conv2d.
            raise ValueError(
                f"DynamicConv supports 1..{self.max_in_channels} input channels, got {in_c}"
            )

        # Prefix mask (NOT one-hot): ones in the first in_c positions. It
        # uses the generator's dtype so the layer also works after the
        # module has been cast to another floating-point type.
        channel_mask = torch.zeros(
            self.max_in_channels,
            device=x.device,
            dtype=self.kernel_generator[0].weight.dtype,
        )
        channel_mask[:in_c] = 1

        # Generate the kernel and reshape it to the layout conv2d expects.
        kernels = self.kernel_generator(channel_mask)
        kernels = kernels.view(self.out_channels, self.max_in_channels, self.kernel_size, self.kernel_size)

        # Keep only the slices that correspond to the channels present.
        kernels = kernels[:, :in_c, :, :]

        return F.conv2d(x, kernels, bias=None, padding=self.kernel_size // 2)
class DynamicCNN(nn.Module):
    """Three conv/BN/ReLU/pool stages, global average pooling and a linear head.

    The first stage is a ``DynamicConv``; the remaining stages are ordinary
    3x3 convolutions.

    Args:
        max_in_channels: Forwarded to ``DynamicConv``.
        num_classes: Length of the output logit vector.
    """

    def __init__(self, max_in_channels=3, num_classes=100):
        super().__init__()
        # Stage 1: channel-adaptive convolution.
        self.dynamic_conv = DynamicConv(max_in_channels=max_in_channels, out_channels=32, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(num_features=32)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stages 2 and 3: fixed-width convolutions.
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=64)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(num_features=128)

        # Head: collapse each feature map to one value, then classify.
        self.global_pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.fc = nn.Linear(in_features=128, out_features=num_classes)

    def forward(self, x):
        """Map a batch of images to class logits."""
        stages = (
            (self.dynamic_conv, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            x = self.pool(F.relu(norm(conv(x))))
        pooled = self.global_pool(x)
        features = pooled.view(pooled.size(0), -1)
        return self.fc(features)


In [None]:
import random
class DynamicTransformWrapper:
    """Callable transform that applies a randomly chosen channel configuration.

    Each call draws one entry of ``channel_options`` with ``random.choice``,
    builds the matching pipeline via ``get_transform`` and applies it.

    NOTE(review): because the draw happens per call, images transformed by
    one instance can come back with different channel counts - confirm how
    they are batched before using this in a DataLoader.
    """

    def __init__(self):
        self.channel_options = ["RGB", "RG", "R"]

    def __call__(self, image):
        picked = random.choice(self.channel_options)
        pipeline = get_transform(picked)
        return pipeline(image)

In [None]:
def split_list_by_channel(list_path, channel_options, seed=0):
    """Shuffle an image list and split it into one list file per channel mode.

    For an input ``<dir>/<stem>.txt`` the function writes
    ``<dir>/<stem>_<mode>.txt`` for every mode in ``channel_options`` (mode
    lower-cased). Every mode receives ``len(lines) // len(channel_options)``
    entries; the last mode additionally receives the remainder.

    Args:
        list_path: Path of the list file to split.
        channel_options: Sequence of channel mode names, e.g. ["RGB", "RG", "R"].
        seed: Seed of the private RNG used for shuffling, so re-running the
            cell reproduces the same split.
    """
    with open(list_path, 'r') as f:
        # Drop blank lines and force a trailing newline on every entry, so a
        # final line without "\n" cannot fuse with its neighbour once the
        # order has been shuffled.
        lines = [line.rstrip("\r\n") + "\n" for line in f if line.strip()]

    # Private RNG: reproducible split without touching the global random state.
    random.Random(seed).shuffle(lines)

    n_per_channel = len(lines) // len(channel_options)
    out_dir, base_name = os.path.split(list_path)
    stem = os.path.splitext(base_name)[0]
    last_index = len(channel_options) - 1

    for i, channel in enumerate(channel_options):
        start = i * n_per_channel
        # The last mode also takes the remainder of the integer division.
        stop = len(lines) if i == last_index else start + n_per_channel
        subset = lines[start:stop]

        output_file = os.path.join(out_dir, f"{stem}_{channel.lower()}.txt")
        with open(output_file, 'w') as f:
            f.writelines(subset)
        print(f"Generated {output_file} with {len(subset)} images.")


channel_options = ["RGB", "RG", "R"]

# Validation split: regenerated on every run (deterministic thanks to the seed).
split_list_by_channel(os.path.join(extract_path, "val.txt"), channel_options)

# NOTE(review): the following cell reads train_rgb.txt / train_rg.txt /
# train_r.txt, but nothing in the visible notebook writes them. Create them
# only when one is missing so an existing split is never overwritten -
# confirm whether these files are meant to ship inside images.zip.
if not all(
    os.path.exists(os.path.join(extract_path, f"train_{ch.lower()}.txt"))
    for ch in channel_options
):
    split_list_by_channel(os.path.join(extract_path, "train.txt"), channel_options)

In [None]:
# Datasets for the dynamic model (RGB, RG, R).
# Each training / validation list is read with the transform of its own mode.
train_dataset_rgb, train_dataset_rg, train_dataset_r = (
    MiniImageNetDataset(
        txt_file=os.path.join(extract_path, list_name),
        root_dir=extract_path,
        transform=get_transform(mode),
        channels=mode,
    )
    for list_name, mode in (
        ("train_rgb.txt", "RGB"),
        ("train_rg.txt", "RG"),
        ("train_r.txt", "R"),
    )
)

# Separate shuffled training loaders, one per channel mode.
train_loader_rgb, train_loader_rg, train_loader_r = (
    DataLoader(split, batch_size=64, shuffle=True)
    for split in (train_dataset_rgb, train_dataset_rg, train_dataset_r)
)

val_dataset_rgb, val_dataset_rg, val_dataset_r = (
    MiniImageNetDataset(
        txt_file=os.path.join(extract_path, list_name),
        root_dir=extract_path,
        transform=get_transform(mode),
        channels=mode,
    )
    for list_name, mode in (
        ("val_rgb.txt", "RGB"),
        ("val_rg.txt", "RG"),
        ("val_r.txt", "R"),
    )
)

# Separate shuffled validation loaders, one per channel mode.
val_loader_rgb, val_loader_rg, val_loader_r = (
    DataLoader(split, batch_size=64, shuffle=True)
    for split in (val_dataset_rgb, val_dataset_rg, val_dataset_r)
)

# Test datasets: the same test.txt list viewed in each of the three modes.
test_dataset_rgb, test_dataset_rg, test_dataset_r = (
    MiniImageNetDataset(
        txt_file=os.path.join(extract_path, "test.txt"),
        root_dir=extract_path,
        transform=get_transform(mode),
        channels=mode,
    )
    for mode in ("RGB", "RG", "R")
)

# Test loaders keep the list order (no shuffling).
test_loader_rgb, test_loader_rg, test_loader_r = (
    DataLoader(split, batch_size=64, shuffle=False)
    for split in (test_dataset_rgb, test_dataset_rg, test_dataset_r)
)

In [None]:
def train_model(model, train_loaders, val_loaders, criterion, optimizer, num_epochs, device,
                patience=5, checkpoint_path="best_model_dynamic.pth"):
    """Train on several loaders with early stopping and return the BEST model.

    Every epoch iterates all ``train_loaders`` one after another (in the
    given order), then evaluates on all ``val_loaders``. Whenever validation
    accuracy improves, the weights are written to ``checkpoint_path``. When
    training ends (normally or by early stopping) those best weights are
    loaded back, so the returned model matches the saved checkpoint rather
    than whatever the last epoch produced.

    Args:
        model: Network to train; expected to already live on ``device``.
        train_loaders: Iterable of loaders yielding ``(images, labels)``.
        val_loaders: Iterable of loaders yielding ``(images, labels)``.
        criterion: Loss callable ``criterion(outputs, labels)``.
        optimizer: Optimizer bound to ``model``'s parameters.
        num_epochs: Maximum number of epochs.
        device: Device the batches are moved to.
        patience: Epochs without validation improvement before stopping.
        checkpoint_path: File the best weights are saved to.

    Returns:
        ``model`` (same object), in eval mode so that later forward passes
        (e.g. model summaries) do not alter BatchNorm running statistics.

    NOTE: this definition replaces the earlier single-loader ``train_model``
    defined in this notebook, which has a different signature.
    """
    best_val_acc = 0.0
    epochs_no_improve = 0
    checkpoint_written = False

    for epoch in range(num_epochs):
        avg_train_loss, train_accuracy = _run_epoch(
            model, train_loaders, criterion, device, optimizer=optimizer
        )
        avg_val_loss, val_accuracy = _run_epoch(
            model, val_loaders, criterion, device
        )

        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_accuracy * 100:.2f}%")
        print(f"Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy * 100:.2f}%")

        # Early stopping on validation accuracy.
        if val_accuracy > best_val_acc:
            best_val_acc = val_accuracy
            epochs_no_improve = 0
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_written = True
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print("早停觸發！")
                break

    # Hand back the best weights, not the last ones. Skipped when no epoch
    # ever improved on the initial 0.0, because no checkpoint exists then.
    if checkpoint_written:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()
    return model


def _run_epoch(model, loaders, criterion, device, optimizer=None):
    """Run one pass over every loader; train if ``optimizer`` is given, else evaluate.

    Returns ``(mean loss per batch, accuracy as a fraction in [0, 1])``.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    num_batches = 0
    correct = 0
    total = 0

    with torch.set_grad_enabled(training):
        for loader in loaders:
            for images, labels in loader:
                images, labels = images.to(device), labels.to(device)

                outputs = model(images)
                loss = criterion(outputs, labels)

                if training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                loss_sum += loss.item()
                num_batches += 1
                # Count hits on-device instead of copying every prediction
                # to the CPU; correct / total equals plain accuracy.
                correct += (outputs.argmax(dim=1) == labels).sum().item()
                total += labels.size(0)

    return loss_sum / num_batches, correct / total

from sklearn.metrics import accuracy_score, classification_report
def test_model(model, test_loader, channel_combination, device):
    """Evaluate ``model`` on ``test_loader`` and print accuracy plus a class report.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        test_loader: Iterable of ``(images, labels)`` batches.
        channel_combination: Label used in the printed messages only.
        device: Device the batches are moved to.

    Returns:
        Accuracy as returned by ``accuracy_score`` (multiplied by 100 only
        for printing).

    NOTE: this definition replaces the earlier ``test_model`` defined in this
    notebook, which has a different signature.
    """
    model.eval()
    print(f"\nTesting with channel combination: {channel_combination}")

    predictions = []
    targets = []

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images)
            # Index 1 of torch.max(..., 1) is the arg-max class per sample.
            predicted = torch.max(logits.data, 1)[1]

            predictions += list(predicted.cpu().numpy())
            targets += list(labels.cpu().numpy())

    accuracy = accuracy_score(targets, predictions)
    print(f"Accuracy for {channel_combination}: {accuracy * 100:.2f}%")
    print("\nClassification Report:")
    print(classification_report(targets, predictions))
    return accuracy

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 50
num_epochs = 10

# Define channel combinations to test
channel_combinations = ["RGB", "RG", "R"]

# Initialize model, criterion, and optimizer
model = DynamicCNN(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# One loader per channel mode; train_model iterates them in this order.
train_loaders = [train_loader_rgb, train_loader_rg, train_loader_r]
val_loaders = [val_loader_rgb, val_loader_rg, val_loader_r]
# train_model returns the module it was given, so `model` and `model_dynamic`
# refer to the same network object.
model_dynamic = train_model(model, train_loaders, val_loaders, criterion, optimizer, num_epochs, device)
torch.save(model_dynamic.state_dict(), 'model_state_dict.pth')

In [None]:
from torchsummary import summary
# Model parameter count.
# NOTE(review): the summary uses a 2-channel input while the FLOPs estimate
# below uses 3 channels - confirm which input shape is intended.
summary(model, (2, 84, 84))

# ptflops can be used when a FLOPs estimate is wanted; it is optional, so a
# missing installation only prints a hint instead of failing the cell.
try:
    from ptflops import get_model_complexity_info
    flops, params = get_model_complexity_info(model, (3, 84, 84), as_strings=True, print_per_layer_stat=False)
    print(f"Computational complexity: {flops}")
    print(f"Number of parameters: {params}")
except ImportError:
    print("Please install ptflops to compute FLOPS: pip install ptflops")

In [None]:
# Testing phase: evaluate the trained model once per channel combination.
print("\n=== Testing Phase ===")
results = {}
for channel_combination in channel_combinations:
    print(f"\nTesting on {channel_combination} data...")

    # The same test.txt list, viewed through this combination's transform.
    test_dataset = MiniImageNetDataset(
        os.path.join(extract_path, "test.txt"),
        extract_path,
        channel_combination,
        transform=get_transform(channel_combination),
    )
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    # Record the accuracy (a fraction) under the combination's name.
    accuracy = test_model(model, test_loader, channel_combination, device)
    results[channel_combination] = accuracy

# Final recap, one line per combination, as percentages.
print("\nSummary of Test Accuracies:")
for channel_combination in results:
    accuracy = results[channel_combination]
    print(f"{channel_combination}: {accuracy * 100:.2f}%")