In [4]:
import torch
import os
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from tqdm.notebook import tqdm
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score, f1_score
from PIL import Image

# Check CUDA availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Dataset path for Kaggle
DATASET_PATH = "../input/dataset-frame/Dataset_frame"  
print(f"Dataset path exists: {os.path.exists(DATASET_PATH)}")

Using device: cuda
Dataset path exists: True


In [5]:
# Second cell - Dataset and Model Classes
class DeepfakeDataset(Dataset):
    def __init__(self, root_dir, split='Train', transform=None):
        """
        Args:
            root_dir: Base directory of the dataset
            split: 'Train', 'Validation', or 'Test'
            transform: Optional transforms
        """
        self.root_dir = os.path.join(root_dir, split)
        self.transform = transform
        self.images = []
        self.labels = []
        
        # Verify directory structure
        print(f"Loading {split} data from: {self.root_dir}")
        
        # Load both Real and Fake images
        for label, class_name in enumerate(['Real', 'Fake']):
            class_dir = os.path.join(self.root_dir, class_name)
            if not os.path.exists(class_dir):
                raise RuntimeError(f"Directory not found: {class_dir}")
            
            files = os.listdir(class_dir)
            print(f"Found {len(files)} {class_name} images")
            
            for img_name in files:
                self.images.append(os.path.join(class_dir, img_name))
                self.labels.append(label)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        try:
            image = Image.open(img_path).convert('RGB')
            label = self.labels[idx]
            
            if self.transform:
                image = self.transform(image)
                
            return image, torch.tensor(label, dtype=torch.float32)
        except Exception as e:
            print(f"Error loading image {img_path}: {str(e)}")
            # Return a default tensor in case of error
            return torch.zeros((3, 224, 224)), torch.tensor(0., dtype=torch.float32)

def get_transforms(is_training=True):
    if is_training:
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    else:
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

class DeepfakeDetector(nn.Module):
    def __init__(self):
        super(DeepfakeDetector, self).__init__()
        # Use EfficientNet-B3 with pretrained weights
        self.base_model = models.efficientnet_b3(weights='IMAGENET1K_V1')
        num_features = self.base_model.classifier[1].in_features
        self.base_model.classifier = nn.Identity()
        
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.6),  # Increased from 0.3 to 0.5
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(p=0.4),  # Increased from 0.2 to 0.3
            nn.Linear(512, 1)
        )

    def forward(self, x):
        features = self.base_model(x)
        return self.classifier(features)

In [6]:
def train_model(config):
    # Initialize model and move to GPU
    model = DeepfakeDetector().to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'], weight_decay=5e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)  # Reduce LR every 3 epochs

    # Initialize datasets
    try:
        train_dataset = DeepfakeDataset(
            config['data_dir'], 
            "Train", 
            transform=get_transforms(True)
        )
        val_dataset = DeepfakeDataset(
            config['data_dir'], 
            "Validation", 
            transform=get_transforms(False)
        )
    except Exception as e:
        print(f"Error initializing datasets: {str(e)}")
        return

    # Create data loaders with updated num_workers
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        persistent_workers=True  # Added for better performance
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        persistent_workers=True  # Added for better performance
    )

    # Initialize training variables
    best_auc = 0.0
    early_stopping_counter = 0
    
    # Updated GradScaler initialization
    scaler = torch.amp.GradScaler('cuda')

    # Training loop
    for epoch in range(config['num_epochs']):
        # Training phase
        model.train()
        train_losses = []
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{config["num_epochs"]}')
        
        for images, labels in progress_bar:
            images = images.to(device, non_blocking=True)  # Added non_blocking=True
            labels = labels.to(device, non_blocking=True)  # Added non_blocking=True
            
            optimizer.zero_grad(set_to_none=True)  # More efficient than zero_grad()
            
            # Updated autocast implementation
            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                outputs = model(images).squeeze()
                loss = criterion(outputs, labels)
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            train_losses.append(loss.item())
            progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})

        # Validation phase
        model.eval()
        val_losses = []
        val_preds = []
        val_labels = []
        
        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc='Validation'):
                images = images.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                
                # Updated autocast implementation
                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                    outputs = model(images).squeeze()
                    loss = criterion(outputs, labels)
                
                val_losses.append(loss.item())
                val_preds.extend(torch.sigmoid(outputs).cpu().numpy())
                val_labels.extend(labels.cpu().numpy())

        # Calculate metrics
        val_auc = roc_auc_score(val_labels, val_preds)
        val_f1 = f1_score(val_labels, np.array(val_preds) > 0.5, average='weighted')
        
        print(f'\nEpoch {epoch+1}:')
        print(f'Train Loss: {np.mean(train_losses):.4f}')
        print(f'Val Loss: {np.mean(val_losses):.4f}')
        print(f'Val AUC: {val_auc:.4f}')
        print(f'Val F1: {val_f1:.4f}')

        # Update learning rate
        scheduler.step(val_auc)

        # Save best model
        # Save best model and move it to output directory
        if val_auc > best_auc:
            best_auc = val_auc
            save_path = f"/kaggle/working/best_b3_model_epoch{epoch+1}.pth"
            
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_auc': best_auc,
                'scaler_state_dict': scaler.state_dict()
            }, save_path)
            
            # Move to Kaggle output directory to prevent loss after session ends
            !cp {save_path} /kaggle/working/best_model.pth  # Save final model
            !cp {save_path} /kaggle/output/  # Save to output
            print(f"✅ Model saved at: {save_path} & /kaggle/output/")

        else:
            early_stopping_counter += 1

        # Early stopping
        if early_stopping_counter >= config['early_stopping_patience']:
            print(f'\nEarly stopping triggered after epoch {epoch+1}')
            break


    print(f'\nBest validation AUC: {best_auc:.4f}')
    return model

# Update the configuration with optimized parameters
config = {
    'data_dir': DATASET_PATH,
    'batch_size': 64,  # You can try 64 if memory allows
    'learning_rate': 1e-4,
    'num_epochs': 5,
    'early_stopping_patience': 5
}

In [4]:
torch.backends.cudnn.benchmark = True  
torch.backends.cuda.matmul.allow_tf32 = True  
torch.backends.cudnn.allow_tf32 = True

# Train the model
model = train_model(config)

Downloading: "https://download.pytorch.org/models/efficientnet_b3_rwightman-b3899882.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b3_rwightman-b3899882.pth
100%|██████████| 47.2M/47.2M [00:00<00:00, 212MB/s]


Loading Train data from: ../input/dataset-frame/Dataset_frame/Train
Found 202669 Real images
Found 205239 Fake images
Loading Validation data from: ../input/dataset-frame/Dataset_frame/Validation
Found 24319 Real images
Found 29190 Fake images


Epoch 1/5:   0%|          | 0/6374 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [12]:
import shutil 



# Source path (file in the input folder)

source_path = '/kaggle/input/efficientnet_b4_model/pytorch/default/1/best_b3_model_epoch4.pth' 



# Destination path (in the output folder)

destination_path = "/kaggle/working/" 



# Copy the file

shutil.copy(source_path, destination_path) 


'/kaggle/working/best_b3_model_epoch4.pth'

In [14]:
import os

model_path = "/kaggle/working/best_b3_model_epoch4.pth"
print(f"File exists: {os.path.exists(model_path)}")


File exists: True


In [15]:
import torch

# Load checkpoint
model_path = "/kaggle/working/best_b3_model_epoch4.pth"
checkpoint = torch.load(model_path, map_location="cuda")  # Ensure it loads on GPU if available

# Initialize the model
model = DeepfakeDetector().to("cuda")

# Load only the model's weights
model.load_state_dict(checkpoint["model_state_dict"])  # Fixes the error
model.eval()  # Set model to evaluation mode

print("✅ Model successfully loaded!")


  checkpoint = torch.load(model_path, map_location="cuda")  # Ensure it loads on GPU if available


✅ Model successfully loaded!


In [16]:
def test_model(model, data_dir):
    test_dataset = DeepfakeDataset(data_dir, "Test", transform=get_transforms(False))
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)

    model.eval()
    test_preds, test_labels = [], []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            images = images.to("cuda")
            outputs = model(images).squeeze()
            test_preds.extend(torch.sigmoid(outputs).cpu().numpy())
            test_labels.extend(labels.numpy())

    auc_score = roc_auc_score(test_labels, test_preds)
    print(f"Test AUC-ROC Score: {auc_score:.4f}")

# Run the test
test_model(model, DATASET_PATH)


Loading Test data from: ../input/dataset-frame/Dataset_frame/Test
Found 28624 Real images
Found 27806 Fake images


Testing:   0%|          | 0/1764 [00:00<?, ?it/s]

Test AUC-ROC Score: 0.9783


In [13]:
def test_model(model_path, data_dir):
    test_dataset = DeepfakeDataset(data_dir, "Test", transform=get_transforms(False))
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)

    model = DeepfakeDetector().to(device)
    model.load_state_dict(torch.load(model_path))
    model.eval()

    test_preds, test_labels = [], []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            images = images.to(device)
            outputs = model(images).squeeze()
            test_preds.extend(torch.sigmoid(outputs).cpu().numpy())
            test_labels.extend(labels.numpy())

    auc_score = roc_auc_score(test_labels, test_preds)
    print(f"Test AUC-ROC Score: {auc_score:.4f}")

if __name__ == "__main__":
    test_model("/kaggle/working/best_b3_model_epoch4.pth", DATASET_PATH)


Loading Test data from: ../input/dataset-frame/Dataset_frame/Test
Found 28624 Real images
Found 27806 Fake images


  model.load_state_dict(torch.load(model_path))


RuntimeError: Error(s) in loading state_dict for DeepfakeDetector:
	Missing key(s) in state_dict: "base_model.features.0.0.weight", "base_model.features.0.1.weight", "base_model.features.0.1.bias", "base_model.features.0.1.running_mean", "base_model.features.0.1.running_var", "base_model.features.1.0.block.0.0.weight", "base_model.features.1.0.block.0.1.weight", "base_model.features.1.0.block.0.1.bias", "base_model.features.1.0.block.0.1.running_mean", "base_model.features.1.0.block.0.1.running_var", "base_model.features.1.0.block.1.fc1.weight", "base_model.features.1.0.block.1.fc1.bias", "base_model.features.1.0.block.1.fc2.weight", "base_model.features.1.0.block.1.fc2.bias", "base_model.features.1.0.block.2.0.weight", "base_model.features.1.0.block.2.1.weight", "base_model.features.1.0.block.2.1.bias", "base_model.features.1.0.block.2.1.running_mean", "base_model.features.1.0.block.2.1.running_var", "base_model.features.1.1.block.0.0.weight", "base_model.features.1.1.block.0.1.weight", "base_model.features.1.1.block.0.1.bias", "base_model.features.1.1.block.0.1.running_mean", "base_model.features.1.1.block.0.1.running_var", "base_model.features.1.1.block.1.fc1.weight", "base_model.features.1.1.block.1.fc1.bias", "base_model.features.1.1.block.1.fc2.weight", "base_model.features.1.1.block.1.fc2.bias", "base_model.features.1.1.block.2.0.weight", "base_model.features.1.1.block.2.1.weight", "base_model.features.1.1.block.2.1.bias", "base_model.features.1.1.block.2.1.running_mean", "base_model.features.1.1.block.2.1.running_var", "base_model.features.2.0.block.0.0.weight", "base_model.features.2.0.block.0.1.weight", "base_model.features.2.0.block.0.1.bias", "base_model.features.2.0.block.0.1.running_mean", "base_model.features.2.0.block.0.1.running_var", "base_model.features.2.0.block.1.0.weight", "base_model.features.2.0.block.1.1.weight", "base_model.features.2.0.block.1.1.bias", "base_model.features.2.0.block.1.1.running_mean", "base_model.features.2.0.block.1.1.running_var", "base_model.features.2.0.block.2.fc1.weight", "base_model.features.2.0.block.2.fc1.bias", "base_model.features.2.0.block.2.fc2.weight", "base_model.features.2.0.block.2.fc2.bias", "base_model.features.2.0.block.3.0.weight", "base_model.features.2.0.block.3.1.weight", "base_model.features.2.0.block.3.1.bias", "base_model.features.2.0.block.3.1.running_mean", "base_model.features.2.0.block.3.1.running_var", "base_model.features.2.1.block.0.0.weight", "base_model.features.2.1.block.0.1.weight", "base_model.features.2.1.block.0.1.bias", "base_model.features.2.1.block.0.1.running_mean", "base_model.features.2.1.block.0.1.running_var", "base_model.features.2.1.block.1.0.weight", "base_model.features.2.1.block.1.1.weight", "base_model.features.2.1.block.1.1.bias", "base_model.features.2.1.block.1.1.running_mean", "base_model.features.2.1.block.1.1.running_var", "base_model.features.2.1.block.2.fc1.weight", "base_model.features.2.1.block.2.fc1.bias", "base_model.features.2.1.block.2.fc2.weight", "base_model.features.2.1.block.2.fc2.bias", "base_model.features.2.1.block.3.0.weight", "base_model.features.2.1.block.3.1.weight", "base_model.features.2.1.block.3.1.bias", "base_model.features.2.1.block.3.1.running_mean", "base_model.features.2.1.block.3.1.running_var", "base_model.features.2.2.block.0.0.weight", "base_model.features.2.2.block.0.1.weight", "base_model.features.2.2.block.0.1.bias", "base_model.features.2.2.block.0.1.running_mean", "base_model.features.2.2.block.0.1.running_var", "base_model.features.2.2.block.1.0.weight", "base_model.features.2.2.block.1.1.weight", "base_model.features.2.2.block.1.1.bias", "base_model.features.2.2.block.1.1.running_mean", "base_model.features.2.2.block.1.1.running_var", "base_model.features.2.2.block.2.fc1.weight", "base_model.features.2.2.block.2.fc1.bias", "base_model.features.2.2.block.2.fc2.weight", "base_model.features.2.2.block.2.fc2.bias", "base_model.features.2.2.block.3.0.weight", "base_model.features.2.2.block.3.1.weight", "base_model.features.2.2.block.3.1.bias", "base_model.features.2.2.block.3.1.running_mean", "base_model.features.2.2.block.3.1.running_var", "base_model.features.3.0.block.0.0.weight", "base_model.features.3.0.block.0.1.weight", "base_model.features.3.0.block.0.1.bias", "base_model.features.3.0.block.0.1.running_mean", "base_model.features.3.0.block.0.1.running_var", "base_model.features.3.0.block.1.0.weight", "base_model.features.3.0.block.1.1.weight", "base_model.features.3.0.block.1.1.bias", "base_model.features.3.0.block.1.1.running_mean", "base_model.features.3.0.block.1.1.running_var", "base_model.features.3.0.block.2.fc1.weight", "base_model.features.3.0.block.2.fc1.bias", "base_model.features.3.0.block.2.fc2.weight", "base_model.features.3.0.block.2.fc2.bias", "base_model.features.3.0.block.3.0.weight", "base_model.features.3.0.block.3.1.weight", "base_model.features.3.0.block.3.1.bias", "base_model.features.3.0.block.3.1.running_mean", "base_model.features.3.0.block.3.1.running_var", "base_model.features.3.1.block.0.0.weight", "base_model.features.3.1.block.0.1.weight", "base_model.features.3.1.block.0.1.bias", "base_model.features.3.1.block.0.1.running_mean", "base_model.features.3.1.block.0.1.running_var", "base_model.features.3.1.block.1.0.weight", "base_model.features.3.1.block.1.1.weight", "base_model.features.3.1.block.1.1.bias", "base_model.features.3.1.block.1.1.running_mean", "base_model.features.3.1.block.1.1.running_var", "base_model.features.3.1.block.2.fc1.weight", "base_model.features.3.1.block.2.fc1.bias", "base_model.features.3.1.block.2.fc2.weight", "base_model.features.3.1.block.2.fc2.bias", "base_model.features.3.1.block.3.0.weight", "base_model.features.3.1.block.3.1.weight", "base_model.features.3.1.block.3.1.bias", "base_model.features.3.1.block.3.1.running_mean", "base_model.features.3.1.block.3.1.running_var", "base_model.features.3.2.block.0.0.weight", "base_model.features.3.2.block.0.1.weight", "base_model.features.3.2.block.0.1.bias", "base_model.features.3.2.block.0.1.running_mean", "base_model.features.3.2.block.0.1.running_var", "base_model.features.3.2.block.1.0.weight", "base_model.features.3.2.block.1.1.weight", "base_model.features.3.2.block.1.1.bias", "base_model.features.3.2.block.1.1.running_mean", "base_model.features.3.2.block.1.1.running_var", "base_model.features.3.2.block.2.fc1.weight", "base_model.features.3.2.block.2.fc1.bias", "base_model.features.3.2.block.2.fc2.weight", "base_model.features.3.2.block.2.fc2.bias", "base_model.features.3.2.block.3.0.weight", "base_model.features.3.2.block.3.1.weight", "base_model.features.3.2.block.3.1.bias", "base_model.features.3.2.block.3.1.running_mean", "base_model.features.3.2.block.3.1.running_var", "base_model.features.4.0.block.0.0.weight", "base_model.features.4.0.block.0.1.weight", "base_model.features.4.0.block.0.1.bias", "base_model.features.4.0.block.0.1.running_mean", "base_model.features.4.0.block.0.1.running_var", "base_model.features.4.0.block.1.0.weight", "base_model.features.4.0.block.1.1.weight", "base_model.features.4.0.block.1.1.bias", "base_model.features.4.0.block.1.1.running_mean", "base_model.features.4.0.block.1.1.running_var", "base_model.features.4.0.block.2.fc1.weight", "base_model.features.4.0.block.2.fc1.bias", "base_model.features.4.0.block.2.fc2.weight", "base_model.features.4.0.block.2.fc2.bias", "base_model.features.4.0.block.3.0.weight", "base_model.features.4.0.block.3.1.weight", "base_model.features.4.0.block.3.1.bias", "base_model.features.4.0.block.3.1.running_mean", "base_model.features.4.0.block.3.1.running_var", "base_model.features.4.1.block.0.0.weight", "base_model.features.4.1.block.0.1.weight", "base_model.features.4.1.block.0.1.bias", "base_model.features.4.1.block.0.1.running_mean", "base_model.features.4.1.block.0.1.running_var", "base_model.features.4.1.block.1.0.weight", "base_model.features.4.1.block.1.1.weight", "base_model.features.4.1.block.1.1.bias", "base_model.features.4.1.block.1.1.running_mean", "base_model.features.4.1.block.1.1.running_var", "base_model.features.4.1.block.2.fc1.weight", "base_model.features.4.1.block.2.fc1.bias", "base_model.features.4.1.block.2.fc2.weight", "base_model.features.4.1.block.2.fc2.bias", "base_model.features.4.1.block.3.0.weight", "base_model.features.4.1.block.3.1.weight", "base_model.features.4.1.block.3.1.bias", "base_model.features.4.1.block.3.1.running_mean", "base_model.features.4.1.block.3.1.running_var", "base_model.features.4.2.block.0.0.weight", "base_model.features.4.2.block.0.1.weight", "base_model.features.4.2.block.0.1.bias", "base_model.features.4.2.block.0.1.running_mean", "base_model.features.4.2.block.0.1.running_var", "base_model.features.4.2.block.1.0.weight", "base_model.features.4.2.block.1.1.weight", "base_model.features.4.2.block.1.1.bias", "base_model.features.4.2.block.1.1.running_mean", "base_model.features.4.2.block.1.1.running_var", "base_model.features.4.2.block.2.fc1.weight", "base_model.features.4.2.block.2.fc1.bias", "base_model.features.4.2.block.2.fc2.weight", "base_model.features.4.2.block.2.fc2.bias", "base_model.features.4.2.block.3.0.weight", "base_model.features.4.2.block.3.1.weight", "base_model.features.4.2.block.3.1.bias", "base_model.features.4.2.block.3.1.running_mean", "base_model.features.4.2.block.3.1.running_var", "base_model.features.4.3.block.0.0.weight", "base_model.features.4.3.block.0.1.weight", "base_model.features.4.3.block.0.1.bias", "base_model.features.4.3.block.0.1.running_mean", "base_model.features.4.3.block.0.1.running_var", "base_model.features.4.3.block.1.0.weight", "base_model.features.4.3.block.1.1.weight", "base_model.features.4.3.block.1.1.bias", "base_model.features.4.3.block.1.1.running_mean", "base_model.features.4.3.block.1.1.running_var", "base_model.features.4.3.block.2.fc1.weight", "base_model.features.4.3.block.2.fc1.bias", "base_model.features.4.3.block.2.fc2.weight", "base_model.features.4.3.block.2.fc2.bias", "base_model.features.4.3.block.3.0.weight", "base_model.features.4.3.block.3.1.weight", "base_model.features.4.3.block.3.1.bias", "base_model.features.4.3.block.3.1.running_mean", "base_model.features.4.3.block.3.1.running_var", "base_model.features.4.4.block.0.0.weight", "base_model.features.4.4.block.0.1.weight", "base_model.features.4.4.block.0.1.bias", "base_model.features.4.4.block.0.1.running_mean", "base_model.features.4.4.block.0.1.running_var", "base_model.features.4.4.block.1.0.weight", "base_model.features.4.4.block.1.1.weight", "base_model.features.4.4.block.1.1.bias", "base_model.features.4.4.block.1.1.running_mean", "base_model.features.4.4.block.1.1.running_var", "base_model.features.4.4.block.2.fc1.weight", "base_model.features.4.4.block.2.fc1.bias", "base_model.features.4.4.block.2.fc2.weight", "base_model.features.4.4.block.2.fc2.bias", "base_model.features.4.4.block.3.0.weight", "base_model.features.4.4.block.3.1.weight", "base_model.features.4.4.block.3.1.bias", "base_model.features.4.4.block.3.1.running_mean", "base_model.features.4.4.block.3.1.running_var", "base_model.features.5.0.block.0.0.weight", "base_model.features.5.0.block.0.1.weight", "base_model.features.5.0.block.0.1.bias", "base_model.features.5.0.block.0.1.running_mean", "base_model.features.5.0.block.0.1.running_var", "base_model.features.5.0.block.1.0.weight", "base_model.features.5.0.block.1.1.weight", "base_model.features.5.0.block.1.1.bias", "base_model.features.5.0.block.1.1.running_mean", "base_model.features.5.0.block.1.1.running_var", "base_model.features.5.0.block.2.fc1.weight", "base_model.features.5.0.block.2.fc1.bias", "base_model.features.5.0.block.2.fc2.weight", "base_model.features.5.0.block.2.fc2.bias", "base_model.features.5.0.block.3.0.weight", "base_model.features.5.0.block.3.1.weight", "base_model.features.5.0.block.3.1.bias", "base_model.features.5.0.block.3.1.running_mean", "base_model.features.5.0.block.3.1.running_var", "base_model.features.5.1.block.0.0.weight", "base_model.features.5.1.block.0.1.weight", "base_model.features.5.1.block.0.1.bias", "base_model.features.5.1.block.0.1.running_mean", "base_model.features.5.1.block.0.1.running_var", "base_model.features.5.1.block.1.0.weight", "base_model.features.5.1.block.1.1.weight", "base_model.features.5.1.block.1.1.bias", "base_model.features.5.1.block.1.1.running_mean", "base_model.features.5.1.block.1.1.running_var", "base_model.features.5.1.block.2.fc1.weight", "base_model.features.5.1.block.2.fc1.bias", "base_model.features.5.1.block.2.fc2.weight", "base_model.features.5.1.block.2.fc2.bias", "base_model.features.5.1.block.3.0.weight", "base_model.features.5.1.block.3.1.weight", "base_model.features.5.1.block.3.1.bias", "base_model.features.5.1.block.3.1.running_mean", "base_model.features.5.1.block.3.1.running_var", "base_model.features.5.2.block.0.0.weight", "base_model.features.5.2.block.0.1.weight", "base_model.features.5.2.block.0.1.bias", "base_model.features.5.2.block.0.1.running_mean", "base_model.features.5.2.block.0.1.running_var", "base_model.features.5.2.block.1.0.weight", "base_model.features.5.2.block.1.1.weight", "base_model.features.5.2.block.1.1.bias", "base_model.features.5.2.block.1.1.running_mean", "base_model.features.5.2.block.1.1.running_var", "base_model.features.5.2.block.2.fc1.weight", "base_model.features.5.2.block.2.fc1.bias", "base_model.features.5.2.block.2.fc2.weight", "base_model.features.5.2.block.2.fc2.bias", "base_model.features.5.2.block.3.0.weight", "base_model.features.5.2.block.3.1.weight", "base_model.features.5.2.block.3.1.bias", "base_model.features.5.2.block.3.1.running_mean", "base_model.features.5.2.block.3.1.running_var", "base_model.features.5.3.block.0.0.weight", "base_model.features.5.3.block.0.1.weight", "base_model.features.5.3.block.0.1.bias", "base_model.features.5.3.block.0.1.running_mean", "base_model.features.5.3.block.0.1.running_var", "base_model.features.5.3.block.1.0.weight", "base_model.features.5.3.block.1.1.weight", "base_model.features.5.3.block.1.1.bias", "base_model.features.5.3.block.1.1.running_mean", "base_model.features.5.3.block.1.1.running_var", "base_model.features.5.3.block.2.fc1.weight", "base_model.features.5.3.block.2.fc1.bias", "base_model.features.5.3.block.2.fc2.weight", "base_model.features.5.3.block.2.fc2.bias", "base_model.features.5.3.block.3.0.weight", "base_model.features.5.3.block.3.1.weight", "base_model.features.5.3.block.3.1.bias", "base_model.features.5.3.block.3.1.running_mean", "base_model.features.5.3.block.3.1.running_var", "base_model.features.5.4.block.0.0.weight", "base_model.features.5.4.block.0.1.weight", "base_model.features.5.4.block.0.1.bias", "base_model.features.5.4.block.0.1.running_mean", "base_model.features.5.4.block.0.1.running_var", "base_model.features.5.4.block.1.0.weight", "base_model.features.5.4.block.1.1.weight", "base_model.features.5.4.block.1.1.bias", "base_model.features.5.4.block.1.1.running_mean", "base_model.features.5.4.block.1.1.running_var", "base_model.features.5.4.block.2.fc1.weight", "base_model.features.5.4.block.2.fc1.bias", "base_model.features.5.4.block.2.fc2.weight", "base_model.features.5.4.block.2.fc2.bias", "base_model.features.5.4.block.3.0.weight", "base_model.features.5.4.block.3.1.weight", "base_model.features.5.4.block.3.1.bias", "base_model.features.5.4.block.3.1.running_mean", "base_model.features.5.4.block.3.1.running_var", "base_model.features.6.0.block.0.0.weight", "base_model.features.6.0.block.0.1.weight", "base_model.features.6.0.block.0.1.bias", "base_model.features.6.0.block.0.1.running_mean", "base_model.features.6.0.block.0.1.running_var", "base_model.features.6.0.block.1.0.weight", "base_model.features.6.0.block.1.1.weight", "base_model.features.6.0.block.1.1.bias", "base_model.features.6.0.block.1.1.running_mean", "base_model.features.6.0.block.1.1.running_var", "base_model.features.6.0.block.2.fc1.weight", "base_model.features.6.0.block.2.fc1.bias", "base_model.features.6.0.block.2.fc2.weight", "base_model.features.6.0.block.2.fc2.bias", "base_model.features.6.0.block.3.0.weight", "base_model.features.6.0.block.3.1.weight", "base_model.features.6.0.block.3.1.bias", "base_model.features.6.0.block.3.1.running_mean", "base_model.features.6.0.block.3.1.running_var", "base_model.features.6.1.block.0.0.weight", "base_model.features.6.1.block.0.1.weight", "base_model.features.6.1.block.0.1.bias", "base_model.features.6.1.block.0.1.running_mean", "base_model.features.6.1.block.0.1.running_var", "base_model.features.6.1.block.1.0.weight", "base_model.features.6.1.block.1.1.weight", "base_model.features.6.1.block.1.1.bias", "base_model.features.6.1.block.1.1.running_mean", "base_model.features.6.1.block.1.1.running_var", "base_model.features.6.1.block.2.fc1.weight", "base_model.features.6.1.block.2.fc1.bias", "base_model.features.6.1.block.2.fc2.weight", "base_model.features.6.1.block.2.fc2.bias", "base_model.features.6.1.block.3.0.weight", "base_model.features.6.1.block.3.1.weight", "base_model.features.6.1.block.3.1.bias", "base_model.features.6.1.block.3.1.running_mean", "base_model.features.6.1.block.3.1.running_var", "base_model.features.6.2.block.0.0.weight", "base_model.features.6.2.block.0.1.weight", "base_model.features.6.2.block.0.1.bias", "base_model.features.6.2.block.0.1.running_mean", "base_model.features.6.2.block.0.1.running_var", "base_model.features.6.2.block.1.0.weight", "base_model.features.6.2.block.1.1.weight", "base_model.features.6.2.block.1.1.bias", "base_model.features.6.2.block.1.1.running_mean", "base_model.features.6.2.block.1.1.running_var", "base_model.features.6.2.block.2.fc1.weight", "base_model.features.6.2.block.2.fc1.bias", "base_model.features.6.2.block.2.fc2.weight", "base_model.features.6.2.block.2.fc2.bias", "base_model.features.6.2.block.3.0.weight", "base_model.features.6.2.block.3.1.weight", "base_model.features.6.2.block.3.1.bias", "base_model.features.6.2.block.3.1.running_mean", "base_model.features.6.2.block.3.1.running_var", "base_model.features.6.3.block.0.0.weight", "base_model.features.6.3.block.0.1.weight", "base_model.features.6.3.block.0.1.bias", "base_model.features.6.3.block.0.1.running_mean", "base_model.features.6.3.block.0.1.running_var", "base_model.features.6.3.block.1.0.weight", "base_model.features.6.3.block.1.1.weight", "base_model.features.6.3.block.1.1.bias", "base_model.features.6.3.block.1.1.running_mean", "base_model.features.6.3.block.1.1.running_var", "base_model.features.6.3.block.2.fc1.weight", "base_model.features.6.3.block.2.fc1.bias", "base_model.features.6.3.block.2.fc2.weight", "base_model.features.6.3.block.2.fc2.bias", "base_model.features.6.3.block.3.0.weight", "base_model.features.6.3.block.3.1.weight", "base_model.features.6.3.block.3.1.bias", "base_model.features.6.3.block.3.1.running_mean", "base_model.features.6.3.block.3.1.running_var", "base_model.features.6.4.block.0.0.weight", "base_model.features.6.4.block.0.1.weight", "base_model.features.6.4.block.0.1.bias", "base_model.features.6.4.block.0.1.running_mean", "base_model.features.6.4.block.0.1.running_var", "base_model.features.6.4.block.1.0.weight", "base_model.features.6.4.block.1.1.weight", "base_model.features.6.4.block.1.1.bias", "base_model.features.6.4.block.1.1.running_mean", "base_model.features.6.4.block.1.1.running_var", "base_model.features.6.4.block.2.fc1.weight", "base_model.features.6.4.block.2.fc1.bias", "base_model.features.6.4.block.2.fc2.weight", "base_model.features.6.4.block.2.fc2.bias", "base_model.features.6.4.block.3.0.weight", "base_model.features.6.4.block.3.1.weight", "base_model.features.6.4.block.3.1.bias", "base_model.features.6.4.block.3.1.running_mean", "base_model.features.6.4.block.3.1.running_var", "base_model.features.6.5.block.0.0.weight", "base_model.features.6.5.block.0.1.weight", "base_model.features.6.5.block.0.1.bias", "base_model.features.6.5.block.0.1.running_mean", "base_model.features.6.5.block.0.1.running_var", "base_model.features.6.5.block.1.0.weight", "base_model.features.6.5.block.1.1.weight", "base_model.features.6.5.block.1.1.bias", "base_model.features.6.5.block.1.1.running_mean", "base_model.features.6.5.block.1.1.running_var", "base_model.features.6.5.block.2.fc1.weight", "base_model.features.6.5.block.2.fc1.bias", "base_model.features.6.5.block.2.fc2.weight", "base_model.features.6.5.block.2.fc2.bias", "base_model.features.6.5.block.3.0.weight", "base_model.features.6.5.block.3.1.weight", "base_model.features.6.5.block.3.1.bias", "base_model.features.6.5.block.3.1.running_mean", "base_model.features.6.5.block.3.1.running_var", "base_model.features.7.0.block.0.0.weight", "base_model.features.7.0.block.0.1.weight", "base_model.features.7.0.block.0.1.bias", "base_model.features.7.0.block.0.1.running_mean", "base_model.features.7.0.block.0.1.running_var", "base_model.features.7.0.block.1.0.weight", "base_model.features.7.0.block.1.1.weight", "base_model.features.7.0.block.1.1.bias", "base_model.features.7.0.block.1.1.running_mean", "base_model.features.7.0.block.1.1.running_var", "base_model.features.7.0.block.2.fc1.weight", "base_model.features.7.0.block.2.fc1.bias", "base_model.features.7.0.block.2.fc2.weight", "base_model.features.7.0.block.2.fc2.bias", "base_model.features.7.0.block.3.0.weight", "base_model.features.7.0.block.3.1.weight", "base_model.features.7.0.block.3.1.bias", "base_model.features.7.0.block.3.1.running_mean", "base_model.features.7.0.block.3.1.running_var", "base_model.features.7.1.block.0.0.weight", "base_model.features.7.1.block.0.1.weight", "base_model.features.7.1.block.0.1.bias", "base_model.features.7.1.block.0.1.running_mean", "base_model.features.7.1.block.0.1.running_var", "base_model.features.7.1.block.1.0.weight", "base_model.features.7.1.block.1.1.weight", "base_model.features.7.1.block.1.1.bias", "base_model.features.7.1.block.1.1.running_mean", "base_model.features.7.1.block.1.1.running_var", "base_model.features.7.1.block.2.fc1.weight", "base_model.features.7.1.block.2.fc1.bias", "base_model.features.7.1.block.2.fc2.weight", "base_model.features.7.1.block.2.fc2.bias", "base_model.features.7.1.block.3.0.weight", "base_model.features.7.1.block.3.1.weight", "base_model.features.7.1.block.3.1.bias", "base_model.features.7.1.block.3.1.running_mean", "base_model.features.7.1.block.3.1.running_var", "base_model.features.8.0.weight", "base_model.features.8.1.weight", "base_model.features.8.1.bias", "base_model.features.8.1.running_mean", "base_model.features.8.1.running_var", "classifier.1.weight", "classifier.1.bias", "classifier.4.weight", "classifier.4.bias". 
	Unexpected key(s) in state_dict: "epoch", "model_state_dict", "optimizer_state_dict", "scheduler_state_dict", "best_auc", "scaler_state_dict". 