In [5]:
from src.flbase.model import Model
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import torchvision.models as models

In [11]:
class s01_PhuongModelNH(Model):
    def __init__(self, config):
        super().__init__(config)
        self.return_embedding = config['FedNH_return_embedding']

        self.conv1 = nn.Conv2d(3,32,kernel_size=4,stride=1,padding=0)
        self.bn1 = nn.BatchNorm2d(32)
        
        self.conv2 = nn.Conv2d(32,64,kernel_size=4,stride=1,padding=0)
        self.bn2 = nn.BatchNorm2d(64)
        
        self.conv3 = nn.Conv2d(64,128,kernel_size=4,stride=1,padding=0)
        self.bn3 = nn.BatchNorm2d(128)
        
        self.conv4 = nn.Conv2d(128,128,kernel_size=4,stride=1,padding=0)
        self.bn4 = nn.BatchNorm2d(128)
        
        self.pool = nn.MaxPool2d(kernel_size=3, stride=3)
        self.pool2= nn.MaxPool2d(kernel_size=3, stride=2)
        
        self.fc1 = nn.Linear(6*6*128,512)
        self.fc2 = nn.Linear(512,config['num_classes'])
        
        temp = nn.Linear(512, config['num_classes'], bias=False).state_dict()['weight']
        self.prototype = nn.Parameter(temp)

        self.flatten = nn.Flatten()
        self.relu = nn.ReLU() 
        self.dropout = nn.Dropout(0.5)

        self.scaling = torch.nn.Parameter(torch.tensor([1.0]))   
        
    def forward(self,x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.pool(x)
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.pool(x)
        x = self.relu(self.bn3(self.conv3(x)))
        x = self.pool2(x)
        x = self.relu(self.bn4(self.conv4(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        feature_embedding = x
        feature_embedding_norm = torch.norm(feature_embedding, p=2, dim=1, keepdim=True).clamp(min=1e-12)
        feature_embedding = torch.div(feature_embedding, feature_embedding_norm)

        if not self.prototype.requires_grad:
            normalized_prototype = self.prototype
        else:
            prototype_norm = torch.norm(self.prototype, p=2, dim=1, keepdim=True).clamp(min=1e-12)
            normalized_prototype = torch.div(self.prototype, prototype_norm)

        logits = torch.matmul(feature_embedding, normalized_prototype.T)
        logits = self.scaling * logits

        if self.return_embedding:
            return feature_embedding, logits
        else:
            return logits


In [12]:
class s02_PhuongModelNH(Model):
     def __init__(self, config):
        super().__init__(config)
        self.return_embedding = config['FedNH_return_embedding']

        self.conv1 = nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        
        self.conv4 = nn.Conv2d(128, 128, kernel_size=4, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        
        self.fc1 = nn.Linear(128 * 4 * 4, 512)  # 128 channels, 4x4 feature map after conv4
        self.fc2 = nn.Linear(512, config['num_classes'])
        
        temp = nn.Linear(512, config['num_classes'], bias=False).state_dict()['weight']
        self.prototype = nn.Parameter(temp)

        self.flatten = nn.Flatten()
        self.relu = nn.ReLU() 
        self.dropout = nn.Dropout(0.5)

        self.scaling = torch.nn.Parameter(torch.tensor([1.0]))
        
     def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        x = self.relu(self.bn4(self.conv4(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        
        feature_embedding = x
        feature_embedding_norm = torch.norm(feature_embedding, p=2, dim=1, keepdim=True).clamp(min=1e-12)
        feature_embedding = torch.div(feature_embedding, feature_embedding_norm)

        if not self.prototype.requires_grad:
            normalized_prototype = self.prototype
        else:
            prototype_norm = torch.norm(self.prototype, p=2, dim=1, keepdim=True).clamp(min=1e-12)
            normalized_prototype = torch.div(self.prototype, prototype_norm)

        logits = torch.matmul(feature_embedding, normalized_prototype.T)
        logits = self.scaling * logits

        if self.return_embedding:
            return feature_embedding, logits
        else:
            return logits


MODEL 3

In [14]:
class PhuongModelNH(Model):
    def __init__(self, config):
        super().__init__(config)
        self.return_embedding = config['FedNH_return_embedding']

        self.conv1 = nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        
        self.conv4 = nn.Conv2d(128, 128, kernel_size=4, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        
        # Update kích thước của lớp fully connected fc1
        self.fc1 = nn.Linear(128 * 4 * 4, 512)  # 128 channels, 4x4 feature map after conv4
        
        self.fc2 = nn.Linear(512, config['num_classes'])
        
        temp = nn.Linear(512, config['num_classes'], bias=False).state_dict()['weight']
        self.prototype = nn.Parameter(temp)

        self.flatten = nn.Flatten()
        self.relu = nn.ReLU() 
        self.dropout = nn.Dropout(0.5)

        self.scaling = torch.nn.Parameter(torch.tensor([1.0]))
        
    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        print('After conv1:',x.size())
        x = self.relu(self.bn2(self.conv2(x)))
        print('After conv2:',x.size())
        x = self.relu(self.bn3(self.conv3(x)))
        print('After conv3:',x.size())
        x = self.relu(self.bn4(self.conv4(x)))
        print('After conv4:',x.size())
        x = self.flatten(x)
        print('After flatten:',x.size())
        x = self.relu(self.fc1(x))
        print('After fc1:',x.size())
        x = self.dropout(x)
        print('After dropout:',x.size())
        x = self.fc2(x)
        print('After fc2:',x.size())
        
        feature_embedding = x
        feature_embedding_norm = torch.norm(feature_embedding, p=2, dim=1, keepdim=True).clamp(min=1e-12)
        feature_embedding = torch.div(feature_embedding, feature_embedding_norm)

        if not self.prototype.requires_grad:
            normalized_prototype = self.prototype
        else:
            prototype_norm = torch.norm(self.prototype, p=2, dim=1, keepdim=True).clamp(min=1e-12)
            normalized_prototype = torch.div(self.prototype, prototype_norm)

        logits = torch.matmul(feature_embedding, normalized_prototype.T)
        logits = self.scaling * logits

        if self.return_embedding:
            return feature_embedding, logits
        else:
            return logits


In [19]:
#transform
from torchvision.transforms import transforms
from PIL import Image

image_size = (224,224)

transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Đọc ảnh từ file
image_path = 'D:/Comvis 2024/FedNH/data/tumorMRI/Training/glioma/augmented_image_0_Tr-gl_0032.jpg'
image = Image.open(image_path).convert('RGB')
image = transform(image)

# Thêm batch dimension
image = image.unsqueeze(0)

# Định nghĩa cấu hình
config = {
    'FedNH_return_embedding': True,
    'num_classes': 1
}
model = PhuongModelNH(config)
model.eval()

# Đưa ảnh qua mô hình
with torch.no_grad():
    output = model(image)

if config['FedNH_return_embedding']:
    feature_embedding, logits = output
    print('Feature Embedding:', feature_embedding)
    print('Logits:', logits)
else:
    logits = output
    print('Logits:', logits)


After conv1: torch.Size([1, 32, 112, 112])
After conv2: torch.Size([1, 64, 56, 56])
After conv3: torch.Size([1, 128, 28, 28])
After conv4: torch.Size([1, 128, 14, 14])
After flatten: torch.Size([1, 25088])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x25088 and 2048x512)