In [12]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
import os
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
import os
import numpy as np
import logging
from datetime import datetime

In [13]:
# Model Definition
class FairXRayClassifier(nn.Module):
    def __init__(self, num_races=6):
        super().__init__()
        resnet = models.resnet50(pretrained=True)
        self.feature_dim = 2048

        # Remove final FC layer
        self.encoder = nn.Sequential(*list(resnet.children())[:-1])
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Disease branch
        self.disease_encoder = nn.Sequential(
            nn.Linear(self.feature_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
        )

        self.disease_head = nn.Sequential(
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),  # Single neuron for binary classification (logits)
        )

        self.disease_decoder = nn.Sequential(
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, self.feature_dim)
        )

        # Race branch
        self.race_encoder = nn.Sequential(
            nn.Linear(self.feature_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
        )

        self.race_head = nn.Sequential(
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, num_races),
        )

        self.race_decoder = nn.Sequential(
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, self.feature_dim)
        )
        
        # Freeze decoder weights
        for param in self.disease_decoder.parameters():
            param.requires_grad = False
        for param in self.race_decoder.parameters():
            param.requires_grad = False

    def forward(self, x):
        features = self.encoder(x)
        features = self.adaptive_pool(features)
        features = features.view(features.size(0), -1)

        # Disease predictions
        disease_features = self.disease_encoder(features)
        disease_pred = self.disease_head(disease_features)  # Outputs logits

        # Race predictions
        race_features = self.race_encoder(features)
        race_logits = self.race_head(race_features)

        # Decoded Features
        disease_decoded = self.disease_decoder(disease_features)
        race_decoded = self.race_decoder(race_features)

        return disease_pred, race_logits, disease_decoded, race_decoded, features


In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = FairXRayClassifier().to(device)



In [15]:
checkpoint = torch.load('/workspace/test/checkpoints/best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

FairXRayClassifier(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): 

In [16]:
class ValidationDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.data = pd.read_csv(csv_file)  # Changed from parquet to csv
        self.transform = transform if transform else self.get_default_transforms()
    
    @staticmethod
    def get_default_transforms():
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
        ])
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img_path = self.data.iloc[idx]['Path']
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        
        label = float(self.data.iloc[idx]['label'])
        fitzpatrick = int(self.data.iloc[idx]['fitzpatrick'])
        
        return {
            'image': image,
            'label': torch.tensor(label, dtype=torch.float),
            'fitzpatrick': torch.tensor(fitzpatrick, dtype=torch.long),
            'img_path': img_path
        }

In [17]:
val_csv = 'val_dataset.csv'
val_dataset = ValidationDataset(val_csv)
val_loader = DataLoader(
        val_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

In [26]:
def inference(model, val_loader, device):
    all_results = []
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(val_loader):
            images = batch['image'].to(device)
            labels = batch['label'].to(device)
            img_paths = batch['img_path']
            race_labels = batch['fitzpatrick']
            
            # Get predictions
            disease_pred, _, _, _, _ = model(images)
            disease_probs = torch.sigmoid(disease_pred)  # Convert logits to probabilities
            disease_pred_binary = (disease_probs > 0.5).float()
            
            # Calculate accuracy
            correct += (disease_pred_binary.squeeze() == labels).sum().item()
            total += labels.size(0)
            
            # Store results
            for i in range(images.size(0)):
                result = {
                    'Path': img_paths[i],
                    'label': labels[i].item(),
                    'fitzpatrick': race_labels[i].item(),
                    'predicton_probability': disease_probs[i].item(),
                    'prediction': disease_pred_binary[i].item()
                }
                all_results.append(result)
            
            logging.info(f"Batch {batch_idx}: Accuracy {(correct/total)*100:.2f}%")
    
    accuracy = 100 * correct / total
    logging.info(f"Overall Accuracy: {accuracy:.2f}%")
    
    return pd.DataFrame(all_results), accuracy

In [27]:
result = inference(model, val_loader, device)

In [28]:
result[0].to_csv('result.csv',index=False)