In [2]:
import torch
import pandas as pd
import torchvision.transforms as transforms
import torchvision.models as models
import torch.nn as nn
from torch.utils.data import DataLoader, WeightedRandomSampler
import os
from tqdm.auto import tqdm
import numpy as np
import timm
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
metadata_path = "C:/Users/sabad/OneDrive/Desktop/APS360/HAM10K Dataset/HAM10K/HAM10000_metadata.csv"

# Load metadata
df = pd.read_csv(metadata_path)

# Drop samples with missing values in key metadata fields
df = df.dropna(subset=['age', 'sex', 'localization'])

# Normalize age
scaler = MinMaxScaler()
df['age_norm'] = scaler.fit_transform(df[['age']])

# One-hot encode categorical variables
df = pd.get_dummies(df, columns=['sex', 'localization'], drop_first=False)

# Define metadata feature columns
meta_features = ['age_norm'] + [col for col in df.columns if col.startswith('sex_') or col.startswith('localization_')]
df['meta'] = df[meta_features].values.tolist()

# Encode labels
label_map = {label: i for i, label in enumerate(df['dx'].unique())}
df['label'] = df['dx'].map(label_map)

img_dirs = [
    "C:/Users/sabad/OneDrive/Desktop/APS360/HAM10K Dataset/HAM10K/HAM10000_images_part_1",
    "C:/Users/sabad/OneDrive/Desktop/APS360/HAM10K Dataset/HAM10K/HAM10000_images_part_2"
]

def find_image_path(image_id):
    for dir_path in img_dirs:
        path = os.path.join(dir_path, f"{image_id}.jpg")
        if os.path.exists(path):
            return path
    return None

df['image_path'] = df['image_id'].apply(find_image_path)

unique_ids = df['lesion_id'].unique()
train_ids, temp_ids = train_test_split(unique_ids, test_size=0.3, random_state=42)
val_ids, test_ids = train_test_split(temp_ids, test_size=0.5, random_state=42)

train_df = df[df['lesion_id'].isin(train_ids)]
val_df = df[df['lesion_id'].isin(val_ids)]
test_df = df[df['lesion_id'].isin(test_ids)]

In [4]:
from torch.utils.data import Dataset
from PIL import Image

class HAM10000FusionDataset(Dataset):
    def __init__(self, dataframe, transform=None):
        self.df = dataframe.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        # Load and transform image
        image = Image.open(row['image_path']).convert('RGB')
        if self.transform:
            image = self.transform(image)

        # Load metadata
        meta = torch.tensor(row['meta'], dtype=torch.float32)

        # Load label
        label = torch.tensor(row['label'], dtype=torch.long)

        return image, meta, label

In [5]:
from PIL import Image
from torch.utils.data import Dataset

class HAM10000Dataset(Dataset):
    def __init__(self, dataframe, img_dirs, transform=None):
        self.df = dataframe.reset_index(drop=True)
        self.img_dirs = img_dirs  # List of directories
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_name = row['image_id'] + ".jpg"

        # Look for the image in both directories
        for dir_path in self.img_dirs:
            image_path = os.path.join(dir_path, image_name)
            if os.path.exists(image_path):
                break
        else:
            raise FileNotFoundError(f"{image_name} not found in provided directories.")

        image = Image.open(image_path).convert('RGB')
        label = row['label']

        if self.transform:
            image = self.transform(image)

        return image, label

In [27]:
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(30),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# train_dataset = HAM10000Dataset(train_df, img_dirs, transform=train_transform)
# val_dataset = HAM10000Dataset(val_df, img_dirs, transform=val_transform)
# test_dataset = HAM10000Dataset(test_df, img_dirs, transform=val_transform)

train_dataset = HAM10000FusionDataset(train_df, transform=train_transform)
val_dataset = HAM10000FusionDataset(val_df, transform=val_transform)
test_dataset = HAM10000FusionDataset(test_df, transform=val_transform)

label_counts = train_df['label'].value_counts().sort_index().values
class_weights = 1. / label_counts
class_weights = class_weights / class_weights.sum() * len(class_weights)  # Normalize

label_to_weight = {label: class_weights[i] for i, label in enumerate(sorted(train_df['label'].unique()))}
sample_weights = train_df['label'].map(label_to_weight).values

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)


train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    sampler=sampler
)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

meta_input_dim = len(meta_features)
print(meta_features)
print("Metadata input dimension:", meta_input_dim)


['age_norm', 'sex_female', 'sex_male', 'sex_unknown', 'localization_abdomen', 'localization_acral', 'localization_back', 'localization_chest', 'localization_ear', 'localization_face', 'localization_foot', 'localization_genital', 'localization_hand', 'localization_lower extremity', 'localization_neck', 'localization_scalp', 'localization_trunk', 'localization_unknown', 'localization_upper extremity']
Metadata input dimension: 19


In [15]:
from collections import Counter

# Count how many samples of each class appear after sampling
label_counter = Counter()

for images, meta, labels in train_loader:
    labels = labels.cpu().numpy()
    label_counter.update(labels)

# Print counts
for label, count in sorted(label_counter.items()):
    print(f"Class {label}: {count} samples")


Class 0: 992 samples
Class 1: 1038 samples
Class 2: 946 samples
Class 3: 1016 samples
Class 4: 958 samples
Class 5: 1018 samples
Class 6: 989 samples


**ResNet50**

In [10]:
def get_resnet50_model(num_classes):
    model = models.resnet50(pretrained=True)
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, num_classes)
    )
    return model

**EfficientNet B0**

In [7]:
class EfficientNetFinetune(nn.Module):
    def __init__(self, num_classes=7):
        super().__init__()
        self.backbone = models.efficientnet_b2(weights='EfficientNet_B2_Weights.DEFAULT')
        in_features = self.backbone.classifier[1].in_features

        # Replace classifier
        self.backbone.classifier = nn.Sequential(
            # nn.Dropout(p=0.3),
            nn.Linear(in_features, num_classes)
        )

    def forward(self, x):
        return self.backbone(x)

**Fusion (CNN + MLP)**

In [8]:
class FusionModel(nn.Module):
    def __init__(self, image_model_name='efficientnet_b0', meta_input_dim=19, num_classes=7):
        super(FusionModel, self).__init__()

        # Image branch (pretrained EfficientNet)
        self.image_model = models.efficientnet_b0(weights='EfficientNet_B0_Weights.DEFAULT')
        in_features = self.image_model.classifier[1].in_features
        self.image_model.classifier = nn.Identity()  # Remove the original classifier

        # Metadata branch (MLP)
        self.meta_net = nn.Sequential(
            nn.Linear(meta_input_dim, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

        # Fusion classifier
        self.classifier = nn.Sequential(
            nn.Linear(in_features + 32, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

    def forward(self, image, metadata):
        img_feat = self.image_model(image)           # [B, in_features]
        meta_feat = self.meta_net(metadata.float())  # [B, 32]
        x = torch.cat([img_feat, meta_feat], dim=1)  # [B, in_features + 32]
        out = self.classifier(x)
        return out

**Generic Training Loop**

In [68]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device):
    model = model.to(device)
    best_val_acc = 0.0
    best_model_wts = model.state_dict()

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 30)

        # ---------- Training ----------
        model.train()
        train_loss = 0.0
        train_correct = 0

        for images, labels in tqdm(train_loader, desc="Training"):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_correct += (outputs.argmax(1) == labels).sum().item()

        train_acc = train_correct / len(train_loader.dataset)

        # ---------- Validation ----------
        model.eval()
        val_loss = 0.0
        val_correct = 0

        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc="Validation"):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                val_correct += (outputs.argmax(1) == labels).sum().item()

        val_acc = val_correct / len(val_loader.dataset)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val   Loss: {val_loss:.4f}, Val   Acc: {val_acc:.4f}")

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_wts = model.state_dict()
            print("Best model updated")

    model.load_state_dict(best_model_wts)
    return model

**Fusion Training Loop**

In [7]:
import copy

def train_fusion_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=10):
    model.to(device)
    best_model_wts = copy.deepcopy(model.state_dict())
    best_val_acc = 0.0

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 40)

        # ---------- TRAIN ----------
        model.train()
        train_loss = 0.0
        train_correct = 0

        for images, metadata, labels in tqdm(train_loader, desc="Training"):
            images = images.to(device)
            metadata = metadata.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images, metadata)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * images.size(0)
            train_correct += (outputs.argmax(dim=1) == labels).sum().item()

        train_loss /= len(train_loader.dataset)
        train_acc = train_correct / len(train_loader.dataset)

        # ---------- VALIDATION ----------
        model.eval()
        val_loss = 0.0
        val_correct = 0

        with torch.no_grad():
            for images, metadata, labels in tqdm(val_loader, desc="Validation"):
                images = images.to(device)
                metadata = metadata.to(device)
                labels = labels.to(device)

                outputs = model(images, metadata)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * images.size(0)
                val_correct += (outputs.argmax(dim=1) == labels).sum().item()

        val_loss /= len(val_loader.dataset)
        val_acc = val_correct / len(val_loader.dataset)

        # ---------- Log ----------
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val   Loss: {val_loss:.4f}, Val   Acc: {val_acc:.4f}")

        # ---------- Save best model ----------
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            print("✅ Best model updated!")

    # Load best weights before returning
    model.load_state_dict(best_model_wts)
    return model


**Two Phase Fusion Training Loop**

In [53]:
def freeze_all_but_meta_and_classifier(model):
    # Freeze image model
    for param in model.image_model.parameters():
        param.requires_grad = False

    # Unfreeze meta_net and classifier
    for param in model.meta_net.parameters():
        param.requires_grad = True
    for param in model.classifier.parameters():
        param.requires_grad = True

def unfreeze_last_block(model):
    for param in model.image_model.features[-1].parameters():
        param.requires_grad = True

In [9]:
def train_two_phase_fusion(model, train_loader, val_loader, device, phase1_epochs=5, phase2_epochs=5, wd=1e-4):
    criterion = nn.CrossEntropyLoss()

    # -------- Phase 1: Freeze CNN, train only MLP + classifier --------
    print("\nPhase 1: Training metadata branch and fusion classifier")
    freeze_all_but_meta_and_classifier(model)
    optimizer1 = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3, weight_decay=wd)
    model = train_fusion_model(model, train_loader, val_loader, criterion, optimizer1, device, num_epochs=phase1_epochs)

    # -------- Phase 2: Unfreeze last CNN block --------
    print("\nPhase 2: Fine-tune last EfficientNet block with rest of model")
    unfreeze_last_block(model)
    optimizer2 = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5, weight_decay=wd)
    model = train_fusion_model(model, train_loader, val_loader, criterion, optimizer2, device, num_epochs=phase2_epochs)

    return model

**Two Phase Training Loop**

In [12]:
def set_parameter_requires_grad(model, feature_extracting=True):
    for param in model.parameters():
        param.requires_grad = not feature_extracting

def unfreeze_last_block(model):
    for name, param in model.named_parameters():
        if "layer4" in name or "fc" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False

In [13]:
def run_one_epoch(model, dataloader, criterion, optimizer, device, train=True):
    if train:
        model.train()
    else:
        model.eval()

    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in tqdm(dataloader, desc="Training" if train else "Validation"):
        images, labels = images.to(device), labels.to(device)

        if train:
            optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)

        if train:
            loss.backward()
            optimizer.step()

        running_loss += loss.item()
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)

    acc = correct / total
    avg_loss = running_loss / len(dataloader)
    return acc, avg_loss


In [14]:
def train_model_two_phase(model, train_loader, val_loader, criterion, device, num_epochs_phase1=5, num_epochs_phase2=5):
    model = model.to(device)
    best_val_acc = 0.0
    best_model_wts = model.state_dict()

    # ------------------ Phase 1: Classifier only ------------------
    print("\nPhase 1: Training classifier only")
    set_parameter_requires_grad(model, feature_extracting=True)
    for param in model.fc.parameters():
        param.requires_grad = True

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-5)

    for epoch in range(num_epochs_phase1):
        print(f"\nEpoch {epoch+1}/{num_epochs_phase1} (Classifier Only)")
        train_acc, train_loss = run_one_epoch(model, train_loader, criterion, optimizer, device, train=True)
        val_acc, val_loss = run_one_epoch(model, val_loader, criterion, optimizer, device, train=False)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val   Loss: {val_loss:.4f}, Val   Acc: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_wts = model.state_dict()
            print("Best model updated")

    # ------------------ Phase 2: Unfreeze last conv block ------------------
    print("\nPhase 2: Unfreezing last conv block")
    unfreeze_last_block(model)

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-5)

    for epoch in range(num_epochs_phase2):
        print(f"\nEpoch {epoch+1}/{num_epochs_phase2} (Last Block + FC)")
        train_acc, train_loss = run_one_epoch(model, train_loader, criterion, optimizer, device, train=True)
        val_acc, val_loss = run_one_epoch(model, val_loader, criterion, optimizer, device, train=False)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val   Loss: {val_loss:.4f}, Val   Acc: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_wts = model.state_dict()
            print("Best model updated")

    model.load_state_dict(best_model_wts)
    return model

**Efficient Net Training Two Phase**

In [None]:
def train_model_two_phase_en(model, train_loader, val_loader, criterion, device, num_epochs_phase1=5, num_epochs_phase2=5):
    model = model.to(device)

    # --- Phase 1: Freeze backbone, train only classifier
    print("\nPhase 1: Training classifier only")
    for param in model.backbone.parameters():
        param.requires_grad = False
    for param in model.backbone.classifier.parameters():
        param.requires_grad = True

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-4)

    model = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs_phase1, device)

    # ---------- Phase 2: Unfreeze only last conv block ----------
    print("\nPhase 2: Fine-tuning classifier + last block only")

    # Unfreeze last block (EfficientNet blocks are in model.backbone.blocks)
    for param in model.backbone.blocks[-1].parameters():
        param.requires_grad = True

    # Also keep classifier unfrozen
    for param in model.backbone.classifier.parameters():
        param.requires_grad = True

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    model = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs_phase2, device)

    return model


In [None]:
# train_dataset = HAM10000Dataset(train_df, img_dirs, train_transform )
# val_dataset = HAM10000Dataset(val_df, img_dirs, val_transform)

# # Count how many samples in each class
# class_counts = train_df['label'].value_counts().sort_index().values
# class_weights = 1. / class_counts

# # Map weights to each sample
# sample_weights = train_df['label'].map(lambda x: class_weights[x]).values

# # Create the sampler
# sampler = WeightedRandomSampler(weights=sample_weights,
#                                  num_samples=len(sample_weights),
#                                  replacement=True)

# train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# val_loader   = DataLoader(val_dataset, batch_size=32, shuffle=False)

**Device and Criterion**

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()
num_classes = 7

In [11]:
print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

Torch version: 2.6.0+cu126
CUDA available: True


**Two Phase ResNet Training**

In [None]:
model = get_resnet50_model(num_classes)

In [30]:
model = train_model_two_phase(model, train_loader, val_loader, criterion, device, num_epochs_phase1=8, num_epochs_phase2=8)


Phase 1: Training classifier only

Epoch 1/8 (Classifier Only)


Training: 100%|██████████| 110/110 [04:05<00:00,  2.23s/it]
Validation: 100%|██████████| 24/24 [01:00<00:00,  2.54s/it]


Train Loss: 1.2886, Train Acc: 0.6355
Val   Loss: 1.2538, Val   Acc: 0.6336
Best model updated

Epoch 2/8 (Classifier Only)


Training: 100%|██████████| 110/110 [02:55<00:00,  1.60s/it]
Validation: 100%|██████████| 24/24 [01:07<00:00,  2.80s/it]


Train Loss: 1.1094, Train Acc: 0.6767
Val   Loss: 1.2168, Val   Acc: 0.6336

Epoch 3/8 (Classifier Only)


Training: 100%|██████████| 110/110 [05:56<00:00,  3.24s/it]
Validation: 100%|██████████| 24/24 [01:06<00:00,  2.78s/it]


Train Loss: 1.0819, Train Acc: 0.6767
Val   Loss: 1.1890, Val   Acc: 0.6336

Epoch 4/8 (Classifier Only)


Training: 100%|██████████| 110/110 [05:18<00:00,  2.89s/it]
Validation: 100%|██████████| 24/24 [01:06<00:00,  2.77s/it]


Train Loss: 1.0563, Train Acc: 0.6767
Val   Loss: 1.1671, Val   Acc: 0.6336

Epoch 5/8 (Classifier Only)


Training: 100%|██████████| 110/110 [05:48<00:00,  3.17s/it]
Validation: 100%|██████████| 24/24 [00:46<00:00,  1.95s/it]


Train Loss: 1.0321, Train Acc: 0.6767
Val   Loss: 1.1409, Val   Acc: 0.6336

Epoch 6/8 (Classifier Only)


Training: 100%|██████████| 110/110 [04:15<00:00,  2.32s/it]
Validation: 100%|██████████| 24/24 [01:05<00:00,  2.74s/it]


Train Loss: 1.0127, Train Acc: 0.6767
Val   Loss: 1.1162, Val   Acc: 0.6356
Best model updated

Epoch 7/8 (Classifier Only)


Training: 100%|██████████| 110/110 [05:17<00:00,  2.89s/it]
Validation: 100%|██████████| 24/24 [00:35<00:00,  1.46s/it]


Train Loss: 0.9909, Train Acc: 0.6770
Val   Loss: 1.0962, Val   Acc: 0.6356

Epoch 8/8 (Classifier Only)


Training: 100%|██████████| 110/110 [03:01<00:00,  1.65s/it]
Validation: 100%|██████████| 24/24 [00:35<00:00,  1.46s/it]


Train Loss: 0.9708, Train Acc: 0.6771
Val   Loss: 1.0782, Val   Acc: 0.6356

Phase 2: Unfreezing last conv block

Epoch 1/8 (Last Block + FC)


Training: 100%|██████████| 110/110 [03:38<00:00,  1.98s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.53s/it]


Train Loss: 0.7879, Train Acc: 0.7168
Val   Loss: 0.7730, Val   Acc: 0.7467
Best model updated

Epoch 2/8 (Last Block + FC)


Training: 100%|██████████| 110/110 [03:38<00:00,  1.99s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.52s/it]


Train Loss: 0.6501, Train Acc: 0.7691
Val   Loss: 0.6880, Val   Acc: 0.7573
Best model updated

Epoch 3/8 (Last Block + FC)


Training: 100%|██████████| 110/110 [03:37<00:00,  1.98s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.53s/it]


Train Loss: 0.5595, Train Acc: 0.7960
Val   Loss: 0.6375, Val   Acc: 0.7851
Best model updated

Epoch 4/8 (Last Block + FC)


Training: 100%|██████████| 110/110 [04:46<00:00,  2.60s/it]
Validation: 100%|██████████| 24/24 [01:12<00:00,  3.04s/it]


Train Loss: 0.5074, Train Acc: 0.8145
Val   Loss: 0.6318, Val   Acc: 0.7817

Epoch 5/8 (Last Block + FC)


Training: 100%|██████████| 110/110 [03:36<00:00,  1.97s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.52s/it]


Train Loss: 0.4624, Train Acc: 0.8347
Val   Loss: 0.6225, Val   Acc: 0.7950
Best model updated

Epoch 6/8 (Last Block + FC)


Training: 100%|██████████| 110/110 [04:05<00:00,  2.23s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.51s/it]


Train Loss: 0.4291, Train Acc: 0.8411
Val   Loss: 0.6306, Val   Acc: 0.7910

Epoch 7/8 (Last Block + FC)


Training: 100%|██████████| 110/110 [03:38<00:00,  1.98s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.52s/it]


Train Loss: 0.4038, Train Acc: 0.8522
Val   Loss: 0.6011, Val   Acc: 0.7910

Epoch 8/8 (Last Block + FC)


Training: 100%|██████████| 110/110 [03:37<00:00,  1.98s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.51s/it]

Train Loss: 0.3748, Train Acc: 0.8645
Val   Loss: 0.5891, Val   Acc: 0.7989
Best model updated





In [32]:
model2 = get_resnet50_model(num_classes)



In [33]:
model2 = train_model_two_phase(model2, train_loader, val_loader, criterion, device, num_epochs_phase1=8, num_epochs_phase2=15)


Phase 1: Training classifier only

Epoch 1/8 (Classifier Only)


Training: 100%|██████████| 110/110 [03:04<00:00,  1.68s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.51s/it]


Train Loss: 1.5448, Train Acc: 0.5225
Val   Loss: 1.3417, Val   Acc: 0.6336
Best model updated

Epoch 2/8 (Classifier Only)


Training: 100%|██████████| 110/110 [03:06<00:00,  1.70s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.52s/it]


Train Loss: 1.1309, Train Acc: 0.6767
Val   Loss: 1.2492, Val   Acc: 0.6336

Epoch 3/8 (Classifier Only)


Training: 100%|██████████| 110/110 [03:06<00:00,  1.69s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.52s/it]


Train Loss: 1.0894, Train Acc: 0.6767
Val   Loss: 1.2130, Val   Acc: 0.6336

Epoch 4/8 (Classifier Only)


Training: 100%|██████████| 110/110 [03:06<00:00,  1.70s/it]
Validation: 100%|██████████| 24/24 [00:35<00:00,  1.50s/it]


Train Loss: 1.0593, Train Acc: 0.6767
Val   Loss: 1.1906, Val   Acc: 0.6336

Epoch 5/8 (Classifier Only)


Training: 100%|██████████| 110/110 [03:06<00:00,  1.70s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.51s/it]


Train Loss: 1.0345, Train Acc: 0.6767
Val   Loss: 1.1601, Val   Acc: 0.6343
Best model updated

Epoch 6/8 (Classifier Only)


Training: 100%|██████████| 110/110 [03:06<00:00,  1.69s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.51s/it]


Train Loss: 1.0215, Train Acc: 0.6767
Val   Loss: 1.1406, Val   Acc: 0.6369
Best model updated

Epoch 7/8 (Classifier Only)


Training: 100%|██████████| 110/110 [03:06<00:00,  1.70s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.51s/it]


Train Loss: 1.0048, Train Acc: 0.6765
Val   Loss: 1.1188, Val   Acc: 0.6362

Epoch 8/8 (Classifier Only)


Training: 100%|██████████| 110/110 [03:06<00:00,  1.70s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.50s/it]


Train Loss: 0.9877, Train Acc: 0.6767
Val   Loss: 1.1050, Val   Acc: 0.6396
Best model updated

Phase 2: Unfreezing last conv block

Epoch 1/15 (Last Block + FC)


Training: 100%|██████████| 110/110 [03:42<00:00,  2.03s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.53s/it]


Train Loss: 0.7963, Train Acc: 0.7176
Val   Loss: 0.7816, Val   Acc: 0.7361
Best model updated

Epoch 2/15 (Last Block + FC)


Training: 100%|██████████| 110/110 [03:42<00:00,  2.02s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.53s/it]


Train Loss: 0.6405, Train Acc: 0.7734
Val   Loss: 0.6698, Val   Acc: 0.7626
Best model updated

Epoch 3/15 (Last Block + FC)


Training: 100%|██████████| 110/110 [03:42<00:00,  2.02s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.54s/it]


Train Loss: 0.5551, Train Acc: 0.7971
Val   Loss: 0.6281, Val   Acc: 0.7765
Best model updated

Epoch 4/15 (Last Block + FC)


Training: 100%|██████████| 110/110 [03:42<00:00,  2.02s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.52s/it]


Train Loss: 0.5094, Train Acc: 0.8141
Val   Loss: 0.5860, Val   Acc: 0.7864
Best model updated

Epoch 5/15 (Last Block + FC)


Training: 100%|██████████| 110/110 [03:42<00:00,  2.02s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.51s/it]


Train Loss: 0.4666, Train Acc: 0.8317
Val   Loss: 0.5850, Val   Acc: 0.7930
Best model updated

Epoch 6/15 (Last Block + FC)


Training: 100%|██████████| 110/110 [03:42<00:00,  2.02s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.52s/it]


Train Loss: 0.4278, Train Acc: 0.8414
Val   Loss: 0.6012, Val   Acc: 0.7857

Epoch 7/15 (Last Block + FC)


Training: 100%|██████████| 110/110 [03:42<00:00,  2.02s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.53s/it]


Train Loss: 0.3970, Train Acc: 0.8562
Val   Loss: 0.5771, Val   Acc: 0.7976
Best model updated

Epoch 8/15 (Last Block + FC)


Training: 100%|██████████| 110/110 [03:42<00:00,  2.02s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.50s/it]


Train Loss: 0.3759, Train Acc: 0.8629
Val   Loss: 0.5856, Val   Acc: 0.8075
Best model updated

Epoch 9/15 (Last Block + FC)


Training: 100%|██████████| 110/110 [03:42<00:00,  2.02s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.52s/it]


Train Loss: 0.3582, Train Acc: 0.8696
Val   Loss: 0.5936, Val   Acc: 0.7923

Epoch 10/15 (Last Block + FC)


Training: 100%|██████████| 110/110 [03:42<00:00,  2.02s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.51s/it]


Train Loss: 0.3356, Train Acc: 0.8791
Val   Loss: 0.5746, Val   Acc: 0.8102
Best model updated

Epoch 11/15 (Last Block + FC)


Training: 100%|██████████| 110/110 [03:42<00:00,  2.02s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.52s/it]


Train Loss: 0.3169, Train Acc: 0.8846
Val   Loss: 0.6087, Val   Acc: 0.7903

Epoch 12/15 (Last Block + FC)


Training: 100%|██████████| 110/110 [03:42<00:00,  2.02s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.52s/it]


Train Loss: 0.2924, Train Acc: 0.8935
Val   Loss: 0.6206, Val   Acc: 0.8016

Epoch 13/15 (Last Block + FC)


Training: 100%|██████████| 110/110 [03:42<00:00,  2.02s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.51s/it]


Train Loss: 0.2838, Train Acc: 0.8968
Val   Loss: 0.5946, Val   Acc: 0.8016

Epoch 14/15 (Last Block + FC)


Training: 100%|██████████| 110/110 [03:42<00:00,  2.03s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.52s/it]


Train Loss: 0.2709, Train Acc: 0.9035
Val   Loss: 0.6299, Val   Acc: 0.8016

Epoch 15/15 (Last Block + FC)


Training: 100%|██████████| 110/110 [03:41<00:00,  2.02s/it]
Validation: 100%|██████████| 24/24 [00:36<00:00,  1.53s/it]

Train Loss: 0.2518, Train Acc: 0.9130
Val   Loss: 0.6311, Val   Acc: 0.8069





**Efficient Net Training**

In [74]:
eff_net = EfficientNetFinetune(num_classes=7)
optimizer = torch.optim.Adam(eff_net.parameters(), lr=1e-4, weight_decay=1e-3)
eff_net = train_model(eff_net, train_loader, val_loader, criterion, optimizer, num_epochs=10, device=device)


Epoch 1/10
------------------------------


Training: 100%|██████████| 109/109 [00:58<00:00,  1.87it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.14it/s]


Train Loss: 126.4991, Train Acc: 0.6109
Val   Loss: 19.1609, Val   Acc: 0.7066
Best model updated

Epoch 2/10
------------------------------


Training: 100%|██████████| 109/109 [00:56<00:00,  1.92it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.20it/s]


Train Loss: 58.9989, Train Acc: 0.8031
Val   Loss: 16.8356, Val   Acc: 0.7341
Best model updated

Epoch 3/10
------------------------------


Training:  17%|█▋        | 19/109 [00:10<00:49,  1.83it/s]


KeyboardInterrupt: 

**Two Phase Efficient Net Training**

In [None]:
eff_net2 = EfficientNetFinetune(num_classes=7)
eff_net2 = train_model_two_phase_en(eff_net2, train_loader, val_loader, criterion, device, num_epochs_phase1=10, num_epochs_phase2=10)


Phase 1: Training classifier only

Epoch 1/5
------------------------------


Training: 100%|██████████| 219/219 [01:26<00:00,  2.52it/s]
Validation: 100%|██████████| 48/48 [00:17<00:00,  2.77it/s]


Train Loss: 251.6243, Train Acc: 0.6607
Val   Loss: 56.5759, Val   Acc: 0.6369
Best model updated

Epoch 2/5
------------------------------


Training: 100%|██████████| 219/219 [01:34<00:00,  2.32it/s]
Validation: 100%|██████████| 48/48 [00:15<00:00,  3.10it/s]


Train Loss: 199.7588, Train Acc: 0.7067
Val   Loss: 51.0701, Val   Acc: 0.6700
Best model updated

Epoch 3/5
------------------------------


Training: 100%|██████████| 219/219 [01:26<00:00,  2.52it/s]
Validation: 100%|██████████| 48/48 [00:15<00:00,  3.08it/s]


Train Loss: 185.8001, Train Acc: 0.7146
Val   Loss: 48.9200, Val   Acc: 0.6726
Best model updated

Epoch 4/5
------------------------------


Training: 100%|██████████| 219/219 [01:26<00:00,  2.53it/s]
Validation: 100%|██████████| 48/48 [00:17<00:00,  2.67it/s]


Train Loss: 176.2818, Train Acc: 0.7248
Val   Loss: 48.1479, Val   Acc: 0.6759
Best model updated

Epoch 5/5
------------------------------


Training: 100%|██████████| 219/219 [01:39<00:00,  2.20it/s]
Validation: 100%|██████████| 48/48 [00:18<00:00,  2.64it/s]


Train Loss: 169.8980, Train Acc: 0.7309
Val   Loss: 45.3714, Val   Acc: 0.6984
Best model updated

Phase 2: Fine-tuning classifier + last block only

Epoch 1/10
------------------------------


Training: 100%|██████████| 219/219 [01:43<00:00,  2.11it/s]
Validation: 100%|██████████| 48/48 [00:18<00:00,  2.64it/s]


Train Loss: 156.2549, Train Acc: 0.7455
Val   Loss: 36.6331, Val   Acc: 0.7493
Best model updated

Epoch 2/10
------------------------------


Training: 100%|██████████| 219/219 [01:43<00:00,  2.11it/s]
Validation: 100%|██████████| 48/48 [00:18<00:00,  2.62it/s]


Train Loss: 142.0548, Train Acc: 0.7656
Val   Loss: 35.0072, Val   Acc: 0.7553
Best model updated

Epoch 3/10
------------------------------


Training: 100%|██████████| 219/219 [01:38<00:00,  2.23it/s]
Validation: 100%|██████████| 48/48 [00:15<00:00,  3.09it/s]


Train Loss: 132.4922, Train Acc: 0.7810
Val   Loss: 32.2194, Val   Acc: 0.7652
Best model updated

Epoch 4/10
------------------------------


Training: 100%|██████████| 219/219 [01:43<00:00,  2.11it/s]
Validation: 100%|██████████| 48/48 [00:18<00:00,  2.61it/s]


Train Loss: 124.0202, Train Acc: 0.7918
Val   Loss: 31.6469, Val   Acc: 0.7712
Best model updated

Epoch 5/10
------------------------------


Training: 100%|██████████| 219/219 [01:43<00:00,  2.11it/s]
Validation: 100%|██████████| 48/48 [00:18<00:00,  2.61it/s]


Train Loss: 121.0269, Train Acc: 0.7995
Val   Loss: 31.0318, Val   Acc: 0.7758
Best model updated

Epoch 6/10
------------------------------


Training: 100%|██████████| 219/219 [01:44<00:00,  2.10it/s]
Validation: 100%|██████████| 48/48 [00:18<00:00,  2.62it/s]


Train Loss: 117.0657, Train Acc: 0.8064
Val   Loss: 30.7574, Val   Acc: 0.7791
Best model updated

Epoch 7/10
------------------------------


Training: 100%|██████████| 219/219 [01:44<00:00,  2.10it/s]
Validation: 100%|██████████| 48/48 [00:18<00:00,  2.61it/s]


Train Loss: 113.3911, Train Acc: 0.8094
Val   Loss: 30.9627, Val   Acc: 0.7712

Epoch 8/10
------------------------------


Training: 100%|██████████| 219/219 [01:44<00:00,  2.10it/s]
Validation: 100%|██████████| 48/48 [00:18<00:00,  2.60it/s]


Train Loss: 109.8897, Train Acc: 0.8161
Val   Loss: 29.6995, Val   Acc: 0.7804
Best model updated

Epoch 9/10
------------------------------


Training: 100%|██████████| 219/219 [01:44<00:00,  2.09it/s]
Validation: 100%|██████████| 48/48 [00:18<00:00,  2.57it/s]


Train Loss: 105.4586, Train Acc: 0.8252
Val   Loss: 29.6436, Val   Acc: 0.7851
Best model updated

Epoch 10/10
------------------------------


Training: 100%|██████████| 219/219 [01:44<00:00,  2.09it/s]
Validation: 100%|██████████| 48/48 [00:18<00:00,  2.61it/s]

Train Loss: 101.9395, Train Acc: 0.8334
Val   Loss: 29.7680, Val   Acc: 0.7877
Best model updated





**Fusion Model Training**

In [54]:
fusion_model = FusionModel(num_classes=7)
freeze_all_but_meta_and_classifier(fusion_model)
unfreeze_last_block(fusion_model)
optimizer = torch.optim.Adam(fusion_model.parameters(), lr=1e-4, weight_decay=1e-5)
fusion_model = train_fusion_model(fusion_model, train_loader, val_loader, criterion, optimizer, device=device, num_epochs=10)


Epoch 1/10
----------------------------------------


Training: 100%|██████████| 109/109 [00:44<00:00,  2.47it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.15it/s]


Train Loss: 1.7034, Train Acc: 0.3440
Val   Loss: 1.5678, Val   Acc: 0.4622
✅ Best model updated!

Epoch 2/10
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.13it/s]


Train Loss: 1.3269, Train Acc: 0.5454
Val   Loss: 1.3239, Val   Acc: 0.5700
✅ Best model updated!

Epoch 3/10
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.50it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.11it/s]


Train Loss: 1.1924, Train Acc: 0.5941
Val   Loss: 1.2602, Val   Acc: 0.5820
✅ Best model updated!

Epoch 4/10
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.14it/s]


Train Loss: 1.0815, Train Acc: 0.6343
Val   Loss: 1.1419, Val   Acc: 0.6155
✅ Best model updated!

Epoch 5/10
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.14it/s]


Train Loss: 1.0213, Train Acc: 0.6500
Val   Loss: 1.1016, Val   Acc: 0.6209
✅ Best model updated!

Epoch 6/10
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.15it/s]


Train Loss: 0.9469, Train Acc: 0.6841
Val   Loss: 1.0494, Val   Acc: 0.6430
✅ Best model updated!

Epoch 7/10
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.52it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.11it/s]


Train Loss: 0.9173, Train Acc: 0.6895
Val   Loss: 1.0249, Val   Acc: 0.6370

Epoch 8/10
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.13it/s]


Train Loss: 0.8773, Train Acc: 0.6977
Val   Loss: 1.0392, Val   Acc: 0.6303

Epoch 9/10
----------------------------------------


Training: 100%|██████████| 109/109 [00:45<00:00,  2.40it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.13it/s]


Train Loss: 0.8526, Train Acc: 0.6984
Val   Loss: 1.0022, Val   Acc: 0.6383

Epoch 10/10
----------------------------------------


Training: 100%|██████████| 109/109 [00:46<00:00,  2.35it/s]
Validation: 100%|██████████| 24/24 [00:12<00:00,  1.91it/s]

Train Loss: 0.8138, Train Acc: 0.7204
Val   Loss: 0.9865, Val   Acc: 0.6484
✅ Best model updated!





In [56]:
fusion_model = FusionModel(num_classes=7)
optimizer = torch.optim.Adam(fusion_model.parameters(), lr=1e-4, weight_decay=1e-4)
fusion_model = train_two_phase_fusion(fusion_model, train_loader, val_loader, device, phase1_epochs=10, phase2_epochs=15, wd=1e-4)


Phase 1: Training metadata branch and fusion classifier

Epoch 1/10
----------------------------------------


Training: 100%|██████████| 109/109 [00:52<00:00,  2.07it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.11it/s]


Train Loss: 1.2234, Train Acc: 0.5497
Val   Loss: 1.0933, Val   Acc: 0.6008
✅ Best model updated!

Epoch 2/10
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.14it/s]


Train Loss: 0.9393, Train Acc: 0.6514
Val   Loss: 0.9860, Val   Acc: 0.6343
✅ Best model updated!

Epoch 3/10
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.52it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.09it/s]


Train Loss: 0.8599, Train Acc: 0.6901
Val   Loss: 1.0099, Val   Acc: 0.6142

Epoch 4/10
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.10it/s]


Train Loss: 0.8019, Train Acc: 0.7000
Val   Loss: 0.8993, Val   Acc: 0.6557
✅ Best model updated!

Epoch 5/10
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.52it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.11it/s]


Train Loss: 0.7816, Train Acc: 0.7122
Val   Loss: 0.9737, Val   Acc: 0.6189

Epoch 6/10
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.50it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.04it/s]


Train Loss: 0.7739, Train Acc: 0.7138
Val   Loss: 0.9045, Val   Acc: 0.6510

Epoch 7/10
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.06it/s]


Train Loss: 0.7456, Train Acc: 0.7237
Val   Loss: 0.8776, Val   Acc: 0.6638
✅ Best model updated!

Epoch 8/10
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.50it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.09it/s]


Train Loss: 0.7187, Train Acc: 0.7326
Val   Loss: 0.8954, Val   Acc: 0.6571

Epoch 9/10
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.50it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.06it/s]


Train Loss: 0.6926, Train Acc: 0.7414
Val   Loss: 0.8315, Val   Acc: 0.6792
✅ Best model updated!

Epoch 10/10
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.50it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.05it/s]


Train Loss: 0.7143, Train Acc: 0.7329
Val   Loss: 0.8794, Val   Acc: 0.6597

Phase 2: Fine-tune last EfficientNet block with rest of model

Epoch 1/15
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.06it/s]


Train Loss: 0.7110, Train Acc: 0.7322
Val   Loss: 0.8688, Val   Acc: 0.6651
✅ Best model updated!

Epoch 2/15
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.50it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.06it/s]


Train Loss: 0.7351, Train Acc: 0.7288
Val   Loss: 0.8209, Val   Acc: 0.6825
✅ Best model updated!

Epoch 3/15
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.50it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.07it/s]


Train Loss: 0.6953, Train Acc: 0.7427
Val   Loss: 0.8451, Val   Acc: 0.6778

Epoch 4/15
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.50it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.06it/s]


Train Loss: 0.7051, Train Acc: 0.7354
Val   Loss: 0.8423, Val   Acc: 0.6792

Epoch 5/15
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.52it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.09it/s]


Train Loss: 0.6969, Train Acc: 0.7418
Val   Loss: 0.8373, Val   Acc: 0.6825

Epoch 6/15
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.09it/s]


Train Loss: 0.6826, Train Acc: 0.7456
Val   Loss: 0.8597, Val   Acc: 0.6718

Epoch 7/15
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.06it/s]


Train Loss: 0.7079, Train Acc: 0.7411
Val   Loss: 0.8442, Val   Acc: 0.6745

Epoch 8/15
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.10it/s]


Train Loss: 0.6799, Train Acc: 0.7490
Val   Loss: 0.8292, Val   Acc: 0.6772

Epoch 9/15
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.49it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.06it/s]


Train Loss: 0.6901, Train Acc: 0.7466
Val   Loss: 0.8426, Val   Acc: 0.6758

Epoch 10/15
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.49it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.06it/s]


Train Loss: 0.6614, Train Acc: 0.7538
Val   Loss: 0.8417, Val   Acc: 0.6765

Epoch 11/15
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.04it/s]


Train Loss: 0.6928, Train Acc: 0.7436
Val   Loss: 0.8367, Val   Acc: 0.6825

Epoch 12/15
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.10it/s]


Train Loss: 0.6789, Train Acc: 0.7449
Val   Loss: 0.8235, Val   Acc: 0.6805

Epoch 13/15
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.52it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.07it/s]


Train Loss: 0.6769, Train Acc: 0.7497
Val   Loss: 0.8163, Val   Acc: 0.6845
✅ Best model updated!

Epoch 14/15
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.50it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.08it/s]


Train Loss: 0.6710, Train Acc: 0.7483
Val   Loss: 0.8252, Val   Acc: 0.6832

Epoch 15/15
----------------------------------------


Training: 100%|██████████| 109/109 [00:43<00:00,  2.50it/s]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.06it/s]

Train Loss: 0.6894, Train Acc: 0.7450
Val   Loss: 0.8110, Val   Acc: 0.6906
✅ Best model updated!





**Image Only Training (EffNet)**

In [32]:
def train_image_only(model, train_loader, val_loader, device, num_epochs=5, lr=1e-4, wd=1e-5):
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)

    for epoch in range(num_epochs):
        print(f"\n[Image-Only] Epoch {epoch+1}/{num_epochs}")
        print("-" * 40)

        # ----- Train -----
        model.train()
        total_loss, correct = 0, 0

        for images, _, labels in tqdm(train_loader, desc="Training"):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * images.size(0)
            correct += (outputs.argmax(1) == labels).sum().item()

        train_loss = total_loss / len(train_loader.dataset)
        train_acc = correct / len(train_loader.dataset)

        # ----- Validation -----
        model.eval()
        val_loss, val_correct = 0, 0

        with torch.no_grad():
            for images, _, labels in tqdm(val_loader, desc="Validation"):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * images.size(0)
                val_correct += (outputs.argmax(1) == labels).sum().item()

        val_loss /= len(val_loader.dataset)
        val_acc = val_correct / len(val_loader.dataset)

        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc:.4f}")
        scheduler.step()

    return model


**Freeze EffNet**

In [30]:
def freeze_image_model(model):
    for param in model.image_model.parameters():
        param.requires_grad = False

    for param in model.meta_net.parameters():
        param.requires_grad = True

    for param in model.classifier.parameters():
        param.requires_grad = True

**Fusion Training Loop**

In [33]:
def train_fusion(model, train_loader, val_loader, device, num_epochs=5, lr=1e-4, wd=1e-5):
    freeze_image_model(model)
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)

    for epoch in range(num_epochs):
        print(f"\n[Fusion Phase] Epoch {epoch+1}/{num_epochs}")
        print("-" * 40)

        # ----- Train -----
        model.train()
        total_loss, correct = 0, 0

        for images, metadata, labels in tqdm(train_loader, desc="Training"):
            images, metadata, labels = images.to(device), metadata.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images, metadata)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * images.size(0)
            correct += (outputs.argmax(1) == labels).sum().item()

        train_loss = total_loss / len(train_loader.dataset)
        train_acc = correct / len(train_loader.dataset)

        # ----- Validation -----
        model.eval()
        val_loss, val_correct = 0, 0

        with torch.no_grad():
            for images, metadata, labels in tqdm(val_loader, desc="Validation"):
                images, metadata, labels = images.to(device), metadata.to(device), labels.to(device)
                outputs = model(images, metadata)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * images.size(0)
                val_correct += (outputs.argmax(1) == labels).sum().item()

        val_loss /= len(val_loader.dataset)
        val_acc = val_correct / len(val_loader.dataset)

        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc:.4f}")
        scheduler.step()

    return model

**Training CNN First, Then Fusion**

In [None]:
# Step 1: Pretrain EfficientNet on image data
image_model = EfficientNetFinetune(num_classes=7)
image_model = train_image_only(image_model, train_loader, val_loader, device, num_epochs=10, wd=1e-4)

# Step 2: Initialize FusionModel and copy pretrained image model
fusion_model = FusionModel(image_model_name='efficientnet_b0', meta_input_dim=meta_input_dim, num_classes=7)
fusion_model.image_model.load_state_dict(image_model.state_dict())

# Step 3: Train FusionModel (freeze EfficientNet)
fusion_model = train_fusion(fusion_model, train_loader, val_loader, device, num_epochs=10, wd=1e-4)

**Test**

In [9]:
from sklearn.metrics import accuracy_score, f1_score

class FusionModel(nn.Module):
    def __init__(self, image_model_name='efficientnet_b2', meta_input_dim=19, num_classes=7):
        super(FusionModel, self).__init__()
        
        # Image branch (pretrained EfficientNet-B2)
        self.image_model = models.efficientnet_b2(weights=models.EfficientNet_B2_Weights.DEFAULT)
        
        # Get the correct number of input features for B2
        in_features = 1408  # EfficientNet-B2 has 1408 output features before classifier
        self.image_model.classifier = nn.Identity()  # Remove original classifier
        
        # Metadata branch
        self.meta_net = nn.Sequential(
            nn.Linear(meta_input_dim, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        
        # Fusion classifier
        self.classifier = nn.Sequential(
            nn.Linear(in_features + 32, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )
    
    def forward(self, image, metadata):
        img_feat = self.image_model(image)
        meta_feat = self.meta_net(metadata)
        combined = torch.cat([img_feat, meta_feat], dim=1)
        return self.classifier(combined)
    
    def freeze_cnn(self):
        """Freeze all CNN parameters"""
        for param in self.image_model.parameters():
            param.requires_grad = False
    
    def unfreeze_classifier(self):
        """Unfreeze only classifier parameters"""
        for param in self.classifier.parameters():
            param.requires_grad = True
        for param in self.meta_net.parameters():
            param.requires_grad = True

In [None]:
def train_cnn_first(model, train_loader, val_loader, epochs=10, device='cuda'):
    """Phase 1: Train only the CNN part on images"""
    # Create temporary classifier for phase 1 training
    temp_classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),  # Slightly higher dropout for B2
        nn.Linear(1408, model.classifier[-1].out_features)
    )
    model.image_model.classifier = temp_classifier
    
    # Freeze metadata branch and fusion classifier
    for param in model.meta_net.parameters():
        param.requires_grad = False
    for param in model.classifier.parameters():
        param.requires_grad = False
    
    model = model.to(device)
    optimizer = torch.optim.Adam(model.image_model.parameters(), lr=1e-4, weight_decay=1e-3)
    criterion = nn.CrossEntropyLoss()
    
    best_val_acc = 0.0
    best_model_weights = None
    
    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0.0
        all_preds = []
        all_targets = []
        
        for images, _, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]"):
            images, targets = images.to(device), targets.to(device)
            
            optimizer.zero_grad()
            outputs = model.image_model(images)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
        
        train_loss /= len(train_loader)
        train_acc = accuracy_score(all_targets, all_preds)
        train_f1 = f1_score(all_targets, all_preds, average='weighted')
        
        # Validation
        val_loss, val_acc, val_f1 = validate(model, val_loader, criterion, device, phase='image')
        
        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | F1: {train_f1:.4f}")
        print(f"Val Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f}")
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_weights = model.image_model.state_dict().copy()
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'val_acc': val_acc,
            }, 'best_model.pth')
            print(f"Saved new best model with val_acc: {val_acc:.4f}")
    
    # Restore best weights and original classifier structure
    model.image_model.load_state_dict(best_model_weights)
    model.image_model.classifier = nn.Identity()  # Restore to original configuration
    return model

def train_fusion_model(model, train_loader, val_loader, epochs=10, device='cuda'):
    """Phase 2: Train fusion model with frozen CNN"""

    cnn_checkpoint = torch.load('best_cnn_model.pth')
    cnn_state_dict = {}
    
    # Filter only CNN weights
    for k, v in cnn_checkpoint['model_state_dict'].items():
        if k.startswith('image_model.'):
            cnn_state_dict[k.replace('image_model.', '')] = v
    
    # Load CNN weights
    model.image_model.load_state_dict(cnn_state_dict)
    
    model.freeze_cnn()
    model.unfreeze_classifier()
    
    model = model.to(device)
    optimizer = torch.optim.Adam([
        {'params': model.meta_net.parameters()},
        {'params': model.classifier.parameters()}
    ], lr=1e-3, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()
    
    best_val_acc = 0.0
    best_model_weights = None
    
    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0.0
        all_preds = []
        all_targets = []
        
        for images, metadata, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]"):
            images = images.to(device)
            metadata = metadata.to(device).float()
            targets = targets.to(device)
            
            optimizer.zero_grad()
            outputs = model(images, metadata)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
        
        train_loss /= len(train_loader)
        train_acc = accuracy_score(all_targets, all_preds)
        train_f1 = f1_score(all_targets, all_preds, average='weighted')
        
        # Validation
        val_loss, val_acc, val_f1 = validate(model, val_loader, criterion, device, phase='fusion')
        
        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | F1: {train_f1:.4f}")
        print(f"Val Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f}")
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_weights = model.image_model.state_dict().copy()
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'val_acc': val_acc,
            }, 'best_model_two.pth')
            print(f"Saved new best model with val_acc: {val_acc:.4f}")
    
    # Restore best weights
    if best_model_weights:
        model.load_state_dict(best_model_weights)
    return model

def validate(model, val_loader, criterion, device, phase='fusion'):
    """Validation function for both phases"""
    model.eval()
    val_loss = 0.0
    all_preds = []
    all_targets = []
    
    with torch.no_grad():
        for batch in val_loader:
            if phase == 'image':
                images, _, targets = batch
                images, targets = images.to(device), targets.to(device)
                outputs = model.image_model(images)
            else:  # fusion
                images, metadata, targets = batch
                images = images.to(device)
                metadata = metadata.to(device).float()
                targets = targets.to(device)
                outputs = model(images, metadata)
            
            loss = criterion(outputs, targets)
            val_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
    
    val_loss /= len(val_loader)
    val_acc = accuracy_score(all_targets, all_preds)
    val_f1 = f1_score(all_targets, all_preds, average='weighted')
    
    return val_loss, val_acc, val_f1

In [13]:
def test_model(model, test_loader, device='cuda'):
    """Evaluate the model on test set and return metrics"""
    model.eval()
    model.to(device)
    
    test_loss = 0.0
    all_preds = []
    all_targets = []
    all_probabilities = []
    
    criterion = nn.CrossEntropyLoss()
    
    with torch.no_grad():
        for images, metadata, targets in tqdm(test_loader, desc="Testing"):
            images = images.to(device)
            metadata = metadata.to(device).float()
            targets = targets.to(device)
            
            outputs = model(images, metadata)
            loss = criterion(outputs, targets)
            
            test_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            probs = torch.softmax(outputs, dim=1)
            
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
            all_probabilities.extend(probs.cpu().numpy())
    
    # Calculate metrics
    test_loss /= len(test_loader)
    test_acc = accuracy_score(all_targets, all_preds)
    test_f1 = f1_score(all_targets, all_preds, average='weighted')
    

    print("\nTest Results:")
    print(f"Loss: {test_loss:.4f}")
    print(f"Accuracy: {test_acc:.4f}")
    print(f"F1 Score: {test_f1:.4f}")
    
    return {
        'loss': test_loss,
        'accuracy': test_acc,
        'f1': test_f1,
        'predictions': all_preds,
        'probabilities': all_probabilities,
        'targets': all_targets,
    }

In [11]:
model = FusionModel(meta_input_dim=19, num_classes=7)
model = train_cnn_first(model, train_loader, val_loader, epochs=10)
model = train_fusion_model(model, train_loader, val_loader, epochs=10)

Epoch 1/10 [Train]: 100%|██████████| 109/109 [00:56<00:00,  1.94it/s]


Epoch 1/10
Train Loss: 1.2166 | Acc: 0.5869 | F1: 0.5802
Val Loss: 0.8372 | Acc: 0.6952 | F1: 0.7222
Saved best model


Epoch 2/10 [Train]: 100%|██████████| 109/109 [00:56<00:00,  1.91it/s]


Epoch 2/10
Train Loss: 0.5612 | Acc: 0.8031 | F1: 0.8014
Val Loss: 0.6746 | Acc: 0.7468 | F1: 0.7672
Saved best model


Epoch 3/10 [Train]: 100%|██████████| 109/109 [00:56<00:00,  1.94it/s]


Epoch 3/10
Train Loss: 0.4076 | Acc: 0.8551 | F1: 0.8541
Val Loss: 0.6368 | Acc: 0.7736 | F1: 0.7907
Saved best model


Epoch 4/10 [Train]: 100%|██████████| 109/109 [00:56<00:00,  1.93it/s]


Epoch 4/10
Train Loss: 0.3227 | Acc: 0.8841 | F1: 0.8836
Val Loss: 0.6849 | Acc: 0.7642 | F1: 0.7818


Epoch 5/10 [Train]: 100%|██████████| 109/109 [00:55<00:00,  1.96it/s]


Epoch 5/10
Train Loss: 0.2556 | Acc: 0.9050 | F1: 0.9046
Val Loss: 0.5918 | Acc: 0.8118 | F1: 0.8189
Saved best model


Epoch 6/10 [Train]: 100%|██████████| 109/109 [00:55<00:00,  1.95it/s]


Epoch 6/10
Train Loss: 0.2329 | Acc: 0.9155 | F1: 0.9152
Val Loss: 0.7142 | Acc: 0.7636 | F1: 0.7838


Epoch 7/10 [Train]: 100%|██████████| 109/109 [00:54<00:00,  1.98it/s]


Epoch 7/10
Train Loss: 0.1982 | Acc: 0.9290 | F1: 0.9288
Val Loss: 0.6680 | Acc: 0.7964 | F1: 0.8062


Epoch 8/10 [Train]: 100%|██████████| 109/109 [00:55<00:00,  1.97it/s]


Epoch 8/10
Train Loss: 0.1732 | Acc: 0.9363 | F1: 0.9360
Val Loss: 0.6271 | Acc: 0.8185 | F1: 0.8235
Saved best model


Epoch 9/10 [Train]: 100%|██████████| 109/109 [00:54<00:00,  2.00it/s]


Epoch 9/10
Train Loss: 0.1444 | Acc: 0.9475 | F1: 0.9473
Val Loss: 0.6453 | Acc: 0.8125 | F1: 0.8189


Epoch 10/10 [Train]: 100%|██████████| 109/109 [00:54<00:00,  1.99it/s]


Epoch 10/10
Train Loss: 0.1322 | Acc: 0.9523 | F1: 0.9522
Val Loss: 0.7124 | Acc: 0.7830 | F1: 0.7997


Epoch 1/10 [Train]: 100%|██████████| 109/109 [00:44<00:00,  2.47it/s]


Epoch 1/10
Train Loss: 0.2844 | Acc: 0.9327 | F1: 0.9326
Val Loss: 0.6111 | Acc: 0.8051 | F1: 0.8161
Saved best model


Epoch 2/10 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.52it/s]


Epoch 2/10
Train Loss: 0.1507 | Acc: 0.9530 | F1: 0.9529
Val Loss: 0.6721 | Acc: 0.8098 | F1: 0.8201
Saved best model


Epoch 3/10 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]


Epoch 3/10
Train Loss: 0.1465 | Acc: 0.9547 | F1: 0.9545
Val Loss: 0.6894 | Acc: 0.8104 | F1: 0.8183
Saved best model


Epoch 4/10 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.53it/s]


Epoch 4/10
Train Loss: 0.1417 | Acc: 0.9552 | F1: 0.9551
Val Loss: 0.6852 | Acc: 0.8098 | F1: 0.8176


Epoch 5/10 [Train]: 100%|██████████| 109/109 [00:50<00:00,  2.17it/s]


Epoch 5/10
Train Loss: 0.1372 | Acc: 0.9544 | F1: 0.9543
Val Loss: 0.7125 | Acc: 0.8111 | F1: 0.8203
Saved best model


Epoch 6/10 [Train]: 100%|██████████| 109/109 [00:44<00:00,  2.47it/s]


Epoch 6/10
Train Loss: 0.1404 | Acc: 0.9517 | F1: 0.9517
Val Loss: 0.7007 | Acc: 0.8044 | F1: 0.8153


Epoch 7/10 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]


Epoch 7/10
Train Loss: 0.1257 | Acc: 0.9567 | F1: 0.9566
Val Loss: 0.7190 | Acc: 0.8125 | F1: 0.8207
Saved best model


Epoch 8/10 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.50it/s]


Epoch 8/10
Train Loss: 0.1265 | Acc: 0.9537 | F1: 0.9538
Val Loss: 0.7154 | Acc: 0.8171 | F1: 0.8246
Saved best model


Epoch 9/10 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.50it/s]


Epoch 9/10
Train Loss: 0.1306 | Acc: 0.9527 | F1: 0.9525
Val Loss: 0.7548 | Acc: 0.8084 | F1: 0.8182


Epoch 10/10 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.50it/s]


Epoch 10/10
Train Loss: 0.1305 | Acc: 0.9567 | F1: 0.9567
Val Loss: 0.6936 | Acc: 0.8192 | F1: 0.8257
Saved best model


In [14]:
results = test_model(model, test_loader)

# You can access individual results:
print(f"Final Test Accuracy: {results['accuracy']:.2%}")
print(f"Final Test F1 Score: {results['f1']:.4f}")

Testing: 100%|██████████| 24/24 [00:15<00:00,  1.51it/s]


Test Results:
Loss: 0.7121
Accuracy: 0.8024
F1 Score: 0.8090
Final Test Accuracy: 80.24%
Final Test F1 Score: 0.8090





In [16]:
from sklearn.metrics import precision_score, recall_score, roc_auc_score
from sklearn.preprocessing import label_binarize
import numpy as np

def test_model(model, test_loader, class_names, device='cuda'):
    """Evaluate model with comprehensive per-class and overall metrics"""
    model.eval()
    model.to(device)
    
    all_preds = []
    all_targets = []
    all_probabilities = []
    
    with torch.no_grad():
        for images, metadata, targets in tqdm(test_loader, desc="Testing"):
            images = images.to(device)
            metadata = metadata.to(device).float()
            targets = targets.to(device)
            
            outputs = model(images, metadata)
            probs = torch.softmax(outputs, dim=1)
            preds = torch.argmax(outputs, dim=1)
            
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
            all_probabilities.extend(probs.cpu().numpy())
    
    # Convert to numpy arrays
    y_true = np.array(all_targets)
    y_pred = np.array(all_preds)
    y_probs = np.array(all_probabilities)
    
    # Number of classes
    n_classes = len(class_names)
    
    # Binarize for AUC calculation
    y_true_bin = label_binarize(y_true, classes=range(n_classes))
    
    # Calculate metrics for each class
    metrics = {
        'class_names': class_names,
        'per_class': {},
        'overall': {}
    }
    
    # Per-class metrics
    for i in range(n_classes):
        metrics['per_class'][class_names[i]] = {
            'accuracy': np.mean(y_true[y_true == i] == y_pred[y_true == i]),
            'precision': precision_score(y_true, y_pred, labels=[i], average='micro'),
            'recall': recall_score(y_true, y_pred, labels=[i], average='micro'),
            'f1': f1_score(y_true, y_pred, labels=[i], average='micro'),
            'auc': roc_auc_score(
                y_true_bin[:, i],
                y_probs[:, i],
                multi_class='ovr'
            ) if n_classes > 2 else roc_auc_score(y_true, y_probs[:, 1])
        }
    
    # Overall metrics (averages)
    metrics['overall'] = {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision_macro': precision_score(y_true, y_pred, average='macro'),
        'precision_weighted': precision_score(y_true, y_pred, average='weighted'),
        'recall_macro': recall_score(y_true, y_pred, average='macro'),
        'recall_weighted': recall_score(y_true, y_pred, average='weighted'),
        'f1_macro': f1_score(y_true, y_pred, average='macro'),
        'f1_weighted': f1_score(y_true, y_pred, average='weighted'),
        'auc_macro': roc_auc_score(
            y_true_bin,
            y_probs,
            multi_class='ovr',
            average='macro'
        ) if n_classes > 2 else roc_auc_score(y_true, y_probs[:, 1]),
        'auc_weighted': roc_auc_score(
            y_true_bin,
            y_probs,
            multi_class='ovr',
            average='weighted'
        ) if n_classes > 2 else roc_auc_score(y_true, y_probs[:, 1])
    }
    
    # Print results
    print("\n=== PER-CLASS METRICS ===")
    for cls in metrics['per_class']:
        print(f"\nClass: {cls}")
        print(f"Accuracy: {metrics['per_class'][cls]['accuracy']:.4f}")
        print(f"Precision: {metrics['per_class'][cls]['precision']:.4f}")
        print(f"Recall: {metrics['per_class'][cls]['recall']:.4f}")
        print(f"F1: {metrics['per_class'][cls]['f1']:.4f}")
        print(f"AUC: {metrics['per_class'][cls]['auc']:.4f}")
    
    print("\n=== OVERALL METRICS ===")
    print(f"Accuracy: {metrics['overall']['accuracy']:.4f}")
    print(f"\nMacro Averages:")
    print(f"Precision: {metrics['overall']['precision_macro']:.4f}")
    print(f"Recall: {metrics['overall']['recall_macro']:.4f}")
    print(f"F1: {metrics['overall']['f1_macro']:.4f}")
    print(f"AUC: {metrics['overall']['auc_macro']:.4f}")
    
    print(f"\nWeighted Averages:")
    print(f"Precision: {metrics['overall']['precision_weighted']:.4f}")
    print(f"Recall: {metrics['overall']['recall_weighted']:.4f}")
    print(f"F1: {metrics['overall']['f1_weighted']:.4f}")
    print(f"AUC: {metrics['overall']['auc_weighted']:.4f}")
    
    return metrics

In [17]:
# Example usage:
class_names = ['nv', 'mel', 'bkl', 'bcc', 'akiec', 'vasc', 'df']  # Your actual class names
results = test_model(model, test_loader, class_names)

Testing: 100%|██████████| 24/24 [00:07<00:00,  3.04it/s]



=== PER-CLASS METRICS ===

Class: nv
Accuracy: 0.7045
Precision: 0.6327
Recall: 0.7045
F1: 0.6667
AUC: 0.9379

Class: mel
Accuracy: 0.8855
Precision: 0.9312
Recall: 0.8855
F1: 0.9078
AUC: 0.9507

Class: bkl
Accuracy: 0.5556
Precision: 0.3125
Recall: 0.5556
F1: 0.4000
AUC: 0.9891

Class: bcc
Accuracy: 0.6257
Precision: 0.4908
Recall: 0.6257
F1: 0.5501
AUC: 0.8960

Class: akiec
Accuracy: 0.7308
Precision: 0.7600
Recall: 0.7308
F1: 0.7451
AUC: 0.9789

Class: vasc
Accuracy: 0.6211
Precision: 0.8310
Recall: 0.6211
F1: 0.7108
AUC: 0.9734

Class: df
Accuracy: 0.5660
Precision: 0.5769
Recall: 0.5660
F1: 0.5714
AUC: 0.9488

=== OVERALL METRICS ===
Accuracy: 0.8024

Macro Averages:
Precision: 0.6479
Recall: 0.6699
F1: 0.6503
AUC: 0.9535

Weighted Averages:
Precision: 0.8210
Recall: 0.8024
F1: 0.8090
AUC: 0.9451


In [18]:
# Save
torch.save(model.state_dict(), 'fusion_model_weights.pth')

# Load
# model = FusionModel(meta_input_dim=19, num_classes=7)  # Recreate model architecture
# model.load_state_dict(torch.load('fusion_model_weights.pth'))
# model.eval()  # Set to evaluation mode

In [20]:
model2 = FusionModel(meta_input_dim=19, num_classes=7)
model2 = train_cnn_first(model2, train_loader, val_loader, epochs=20)
model2 = train_fusion_model(model2, train_loader, val_loader, epochs=20)
results = test_model(model2, test_loader, class_names)

Epoch 1/20 [Train]: 100%|██████████| 109/109 [00:55<00:00,  1.97it/s]


Epoch 1/20
Train Loss: 1.2060 | Acc: 0.5872 | F1: 0.5818
Val Loss: 0.8148 | Acc: 0.6946 | F1: 0.7222
Saved new best model with val_acc: 0.6946


Epoch 2/20 [Train]: 100%|██████████| 109/109 [00:57<00:00,  1.91it/s]


Epoch 2/20
Train Loss: 0.5641 | Acc: 0.8009 | F1: 0.7990
Val Loss: 0.6780 | Acc: 0.7515 | F1: 0.7698
Saved new best model with val_acc: 0.7515


Epoch 3/20 [Train]: 100%|██████████| 109/109 [00:55<00:00,  1.98it/s]


Epoch 3/20
Train Loss: 0.3965 | Acc: 0.8560 | F1: 0.8547
Val Loss: 0.6607 | Acc: 0.7528 | F1: 0.7726
Saved new best model with val_acc: 0.7528


Epoch 4/20 [Train]: 100%|██████████| 109/109 [00:54<00:00,  2.00it/s]


Epoch 4/20
Train Loss: 0.3322 | Acc: 0.8777 | F1: 0.8772
Val Loss: 0.6356 | Acc: 0.7709 | F1: 0.7892
Saved new best model with val_acc: 0.7709


Epoch 5/20 [Train]: 100%|██████████| 109/109 [00:54<00:00,  2.01it/s]


Epoch 5/20
Train Loss: 0.2545 | Acc: 0.9090 | F1: 0.9088
Val Loss: 0.6121 | Acc: 0.7991 | F1: 0.8089
Saved new best model with val_acc: 0.7991


Epoch 6/20 [Train]: 100%|██████████| 109/109 [00:54<00:00,  2.00it/s]


Epoch 6/20
Train Loss: 0.2233 | Acc: 0.9165 | F1: 0.9162
Val Loss: 0.6198 | Acc: 0.7997 | F1: 0.8077
Saved new best model with val_acc: 0.7997


Epoch 7/20 [Train]: 100%|██████████| 109/109 [00:54<00:00,  2.00it/s]


Epoch 7/20
Train Loss: 0.1951 | Acc: 0.9300 | F1: 0.9297
Val Loss: 0.6381 | Acc: 0.8058 | F1: 0.8152
Saved new best model with val_acc: 0.8058


Epoch 8/20 [Train]: 100%|██████████| 109/109 [00:54<00:00,  2.00it/s]


Epoch 8/20
Train Loss: 0.1643 | Acc: 0.9431 | F1: 0.9430
Val Loss: 0.6459 | Acc: 0.7944 | F1: 0.8084


Epoch 9/20 [Train]: 100%|██████████| 109/109 [00:54<00:00,  2.00it/s]


Epoch 9/20
Train Loss: 0.1491 | Acc: 0.9498 | F1: 0.9496
Val Loss: 0.6621 | Acc: 0.8031 | F1: 0.8157


Epoch 10/20 [Train]: 100%|██████████| 109/109 [00:55<00:00,  1.98it/s]


Epoch 10/20
Train Loss: 0.1376 | Acc: 0.9513 | F1: 0.9511
Val Loss: 0.5872 | Acc: 0.8379 | F1: 0.8404
Saved new best model with val_acc: 0.8379


Epoch 11/20 [Train]: 100%|██████████| 109/109 [00:54<00:00,  1.98it/s]


Epoch 11/20
Train Loss: 0.1173 | Acc: 0.9609 | F1: 0.9609
Val Loss: 0.6517 | Acc: 0.8171 | F1: 0.8244


Epoch 12/20 [Train]: 100%|██████████| 109/109 [00:55<00:00,  1.98it/s]


Epoch 12/20
Train Loss: 0.1142 | Acc: 0.9615 | F1: 0.9614
Val Loss: 0.6820 | Acc: 0.8098 | F1: 0.8189


Epoch 13/20 [Train]: 100%|██████████| 109/109 [00:54<00:00,  1.98it/s]


Epoch 13/20
Train Loss: 0.1008 | Acc: 0.9656 | F1: 0.9656
Val Loss: 0.7444 | Acc: 0.7971 | F1: 0.8100


Epoch 14/20 [Train]: 100%|██████████| 109/109 [00:54<00:00,  1.99it/s]


Epoch 14/20
Train Loss: 0.0944 | Acc: 0.9661 | F1: 0.9661
Val Loss: 0.6681 | Acc: 0.8178 | F1: 0.8234


Epoch 15/20 [Train]: 100%|██████████| 109/109 [00:54<00:00,  1.98it/s]


Epoch 15/20
Train Loss: 0.0935 | Acc: 0.9665 | F1: 0.9665
Val Loss: 0.6635 | Acc: 0.8279 | F1: 0.8302


Epoch 16/20 [Train]: 100%|██████████| 109/109 [00:54<00:00,  2.01it/s]


Epoch 16/20
Train Loss: 0.0938 | Acc: 0.9672 | F1: 0.9672
Val Loss: 0.7086 | Acc: 0.8212 | F1: 0.8288


Epoch 17/20 [Train]: 100%|██████████| 109/109 [00:54<00:00,  2.00it/s]


Epoch 17/20
Train Loss: 0.0763 | Acc: 0.9743 | F1: 0.9742
Val Loss: 0.7209 | Acc: 0.8285 | F1: 0.8340


Epoch 18/20 [Train]: 100%|██████████| 109/109 [00:54<00:00,  2.01it/s]


Epoch 18/20
Train Loss: 0.0790 | Acc: 0.9717 | F1: 0.9716
Val Loss: 0.6859 | Acc: 0.8326 | F1: 0.8352


Epoch 19/20 [Train]: 100%|██████████| 109/109 [00:54<00:00,  1.99it/s]


Epoch 19/20
Train Loss: 0.0798 | Acc: 0.9730 | F1: 0.9729
Val Loss: 0.7067 | Acc: 0.8399 | F1: 0.8401
Saved new best model with val_acc: 0.8399


Epoch 20/20 [Train]: 100%|██████████| 109/109 [00:54<00:00,  1.99it/s]


Epoch 20/20
Train Loss: 0.0725 | Acc: 0.9756 | F1: 0.9755
Val Loss: 0.6872 | Acc: 0.8392 | F1: 0.8353


Epoch 1/20 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]


Epoch 1/20
Train Loss: 0.2160 | Acc: 0.9619 | F1: 0.9619
Val Loss: 0.5905 | Acc: 0.8346 | F1: 0.8375
Saved new best model with val_acc: 0.8346


Epoch 2/20 [Train]: 100%|██████████| 109/109 [00:42<00:00,  2.54it/s]


Epoch 2/20
Train Loss: 0.0990 | Acc: 0.9747 | F1: 0.9746
Val Loss: 0.6520 | Acc: 0.8299 | F1: 0.8342


Epoch 3/20 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]


Epoch 3/20
Train Loss: 0.0907 | Acc: 0.9746 | F1: 0.9745
Val Loss: 0.6523 | Acc: 0.8359 | F1: 0.8368
Saved new best model with val_acc: 0.8359


Epoch 4/20 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.50it/s]


Epoch 4/20
Train Loss: 0.0772 | Acc: 0.9747 | F1: 0.9747
Val Loss: 0.6543 | Acc: 0.8433 | F1: 0.8411
Saved new best model with val_acc: 0.8433


Epoch 5/20 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.50it/s]


Epoch 5/20
Train Loss: 0.0735 | Acc: 0.9767 | F1: 0.9767
Val Loss: 0.6875 | Acc: 0.8379 | F1: 0.8404


Epoch 6/20 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.50it/s]


Epoch 6/20
Train Loss: 0.0655 | Acc: 0.9793 | F1: 0.9793
Val Loss: 0.7074 | Acc: 0.8366 | F1: 0.8386


Epoch 7/20 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]


Epoch 7/20
Train Loss: 0.0697 | Acc: 0.9783 | F1: 0.9783
Val Loss: 0.7268 | Acc: 0.8346 | F1: 0.8383


Epoch 8/20 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.50it/s]


Epoch 8/20
Train Loss: 0.0699 | Acc: 0.9771 | F1: 0.9771
Val Loss: 0.7064 | Acc: 0.8466 | F1: 0.8444
Saved new best model with val_acc: 0.8466


Epoch 9/20 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.52it/s]


Epoch 9/20
Train Loss: 0.0674 | Acc: 0.9773 | F1: 0.9773
Val Loss: 0.7223 | Acc: 0.8433 | F1: 0.8432


Epoch 10/20 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]


Epoch 10/20
Train Loss: 0.0650 | Acc: 0.9780 | F1: 0.9780
Val Loss: 0.7123 | Acc: 0.8500 | F1: 0.8487
Saved new best model with val_acc: 0.8500


Epoch 11/20 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]


Epoch 11/20
Train Loss: 0.0655 | Acc: 0.9773 | F1: 0.9772
Val Loss: 0.7455 | Acc: 0.8392 | F1: 0.8417


Epoch 12/20 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.50it/s]


Epoch 12/20
Train Loss: 0.0591 | Acc: 0.9807 | F1: 0.9807
Val Loss: 0.7328 | Acc: 0.8399 | F1: 0.8401


Epoch 13/20 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.52it/s]


Epoch 13/20
Train Loss: 0.0667 | Acc: 0.9771 | F1: 0.9771
Val Loss: 0.7417 | Acc: 0.8446 | F1: 0.8438


Epoch 14/20 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]


Epoch 14/20
Train Loss: 0.0593 | Acc: 0.9812 | F1: 0.9812
Val Loss: 0.7663 | Acc: 0.8433 | F1: 0.8431


Epoch 15/20 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]


Epoch 15/20
Train Loss: 0.0591 | Acc: 0.9816 | F1: 0.9816
Val Loss: 0.7825 | Acc: 0.8446 | F1: 0.8421


Epoch 16/20 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.53it/s]


Epoch 16/20
Train Loss: 0.0634 | Acc: 0.9784 | F1: 0.9784
Val Loss: 0.7952 | Acc: 0.8379 | F1: 0.8407


Epoch 17/20 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]


Epoch 17/20
Train Loss: 0.0625 | Acc: 0.9797 | F1: 0.9797
Val Loss: 0.7697 | Acc: 0.8339 | F1: 0.8356


Epoch 18/20 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.52it/s]


Epoch 18/20
Train Loss: 0.0553 | Acc: 0.9828 | F1: 0.9827
Val Loss: 0.7736 | Acc: 0.8446 | F1: 0.8437


Epoch 19/20 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.52it/s]


Epoch 19/20
Train Loss: 0.0554 | Acc: 0.9829 | F1: 0.9829
Val Loss: 0.7651 | Acc: 0.8446 | F1: 0.8436


Epoch 20/20 [Train]: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]


Epoch 20/20
Train Loss: 0.0556 | Acc: 0.9828 | F1: 0.9827
Val Loss: 0.7723 | Acc: 0.8285 | F1: 0.8318


RuntimeError: Error(s) in loading state_dict for FusionModel:
	Missing key(s) in state_dict: "image_model.features.0.0.weight", "image_model.features.0.1.weight", "image_model.features.0.1.bias", "image_model.features.0.1.running_mean", "image_model.features.0.1.running_var", "image_model.features.1.0.block.0.0.weight", "image_model.features.1.0.block.0.1.weight", "image_model.features.1.0.block.0.1.bias", "image_model.features.1.0.block.0.1.running_mean", "image_model.features.1.0.block.0.1.running_var", "image_model.features.1.0.block.1.fc1.weight", "image_model.features.1.0.block.1.fc1.bias", "image_model.features.1.0.block.1.fc2.weight", "image_model.features.1.0.block.1.fc2.bias", "image_model.features.1.0.block.2.0.weight", "image_model.features.1.0.block.2.1.weight", "image_model.features.1.0.block.2.1.bias", "image_model.features.1.0.block.2.1.running_mean", "image_model.features.1.0.block.2.1.running_var", "image_model.features.1.1.block.0.0.weight", "image_model.features.1.1.block.0.1.weight", "image_model.features.1.1.block.0.1.bias", "image_model.features.1.1.block.0.1.running_mean", "image_model.features.1.1.block.0.1.running_var", "image_model.features.1.1.block.1.fc1.weight", "image_model.features.1.1.block.1.fc1.bias", "image_model.features.1.1.block.1.fc2.weight", "image_model.features.1.1.block.1.fc2.bias", "image_model.features.1.1.block.2.0.weight", "image_model.features.1.1.block.2.1.weight", "image_model.features.1.1.block.2.1.bias", "image_model.features.1.1.block.2.1.running_mean", "image_model.features.1.1.block.2.1.running_var", "image_model.features.2.0.block.0.0.weight", "image_model.features.2.0.block.0.1.weight", "image_model.features.2.0.block.0.1.bias", "image_model.features.2.0.block.0.1.running_mean", "image_model.features.2.0.block.0.1.running_var", "image_model.features.2.0.block.1.0.weight", "image_model.features.2.0.block.1.1.weight", "image_model.features.2.0.block.1.1.bias", "image_model.features.2.0.block.1.1.running_mean", "image_model.features.2.0.block.1.1.running_var", "image_model.features.2.0.block.2.fc1.weight", "image_model.features.2.0.block.2.fc1.bias", "image_model.features.2.0.block.2.fc2.weight", "image_model.features.2.0.block.2.fc2.bias", "image_model.features.2.0.block.3.0.weight", "image_model.features.2.0.block.3.1.weight", "image_model.features.2.0.block.3.1.bias", "image_model.features.2.0.block.3.1.running_mean", "image_model.features.2.0.block.3.1.running_var", "image_model.features.2.1.block.0.0.weight", "image_model.features.2.1.block.0.1.weight", "image_model.features.2.1.block.0.1.bias", "image_model.features.2.1.block.0.1.running_mean", "image_model.features.2.1.block.0.1.running_var", "image_model.features.2.1.block.1.0.weight", "image_model.features.2.1.block.1.1.weight", "image_model.features.2.1.block.1.1.bias", "image_model.features.2.1.block.1.1.running_mean", "image_model.features.2.1.block.1.1.running_var", "image_model.features.2.1.block.2.fc1.weight", "image_model.features.2.1.block.2.fc1.bias", "image_model.features.2.1.block.2.fc2.weight", "image_model.features.2.1.block.2.fc2.bias", "image_model.features.2.1.block.3.0.weight", "image_model.features.2.1.block.3.1.weight", "image_model.features.2.1.block.3.1.bias", "image_model.features.2.1.block.3.1.running_mean", "image_model.features.2.1.block.3.1.running_var", "image_model.features.2.2.block.0.0.weight", "image_model.features.2.2.block.0.1.weight", "image_model.features.2.2.block.0.1.bias", "image_model.features.2.2.block.0.1.running_mean", "image_model.features.2.2.block.0.1.running_var", "image_model.features.2.2.block.1.0.weight", "image_model.features.2.2.block.1.1.weight", "image_model.features.2.2.block.1.1.bias", "image_model.features.2.2.block.1.1.running_mean", "image_model.features.2.2.block.1.1.running_var", "image_model.features.2.2.block.2.fc1.weight", "image_model.features.2.2.block.2.fc1.bias", "image_model.features.2.2.block.2.fc2.weight", "image_model.features.2.2.block.2.fc2.bias", "image_model.features.2.2.block.3.0.weight", "image_model.features.2.2.block.3.1.weight", "image_model.features.2.2.block.3.1.bias", "image_model.features.2.2.block.3.1.running_mean", "image_model.features.2.2.block.3.1.running_var", "image_model.features.3.0.block.0.0.weight", "image_model.features.3.0.block.0.1.weight", "image_model.features.3.0.block.0.1.bias", "image_model.features.3.0.block.0.1.running_mean", "image_model.features.3.0.block.0.1.running_var", "image_model.features.3.0.block.1.0.weight", "image_model.features.3.0.block.1.1.weight", "image_model.features.3.0.block.1.1.bias", "image_model.features.3.0.block.1.1.running_mean", "image_model.features.3.0.block.1.1.running_var", "image_model.features.3.0.block.2.fc1.weight", "image_model.features.3.0.block.2.fc1.bias", "image_model.features.3.0.block.2.fc2.weight", "image_model.features.3.0.block.2.fc2.bias", "image_model.features.3.0.block.3.0.weight", "image_model.features.3.0.block.3.1.weight", "image_model.features.3.0.block.3.1.bias", "image_model.features.3.0.block.3.1.running_mean", "image_model.features.3.0.block.3.1.running_var", "image_model.features.3.1.block.0.0.weight", "image_model.features.3.1.block.0.1.weight", "image_model.features.3.1.block.0.1.bias", "image_model.features.3.1.block.0.1.running_mean", "image_model.features.3.1.block.0.1.running_var", "image_model.features.3.1.block.1.0.weight", "image_model.features.3.1.block.1.1.weight", "image_model.features.3.1.block.1.1.bias", "image_model.features.3.1.block.1.1.running_mean", "image_model.features.3.1.block.1.1.running_var", "image_model.features.3.1.block.2.fc1.weight", "image_model.features.3.1.block.2.fc1.bias", "image_model.features.3.1.block.2.fc2.weight", "image_model.features.3.1.block.2.fc2.bias", "image_model.features.3.1.block.3.0.weight", "image_model.features.3.1.block.3.1.weight", "image_model.features.3.1.block.3.1.bias", "image_model.features.3.1.block.3.1.running_mean", "image_model.features.3.1.block.3.1.running_var", "image_model.features.3.2.block.0.0.weight", "image_model.features.3.2.block.0.1.weight", "image_model.features.3.2.block.0.1.bias", "image_model.features.3.2.block.0.1.running_mean", "image_model.features.3.2.block.0.1.running_var", "image_model.features.3.2.block.1.0.weight", "image_model.features.3.2.block.1.1.weight", "image_model.features.3.2.block.1.1.bias", "image_model.features.3.2.block.1.1.running_mean", "image_model.features.3.2.block.1.1.running_var", "image_model.features.3.2.block.2.fc1.weight", "image_model.features.3.2.block.2.fc1.bias", "image_model.features.3.2.block.2.fc2.weight", "image_model.features.3.2.block.2.fc2.bias", "image_model.features.3.2.block.3.0.weight", "image_model.features.3.2.block.3.1.weight", "image_model.features.3.2.block.3.1.bias", "image_model.features.3.2.block.3.1.running_mean", "image_model.features.3.2.block.3.1.running_var", "image_model.features.4.0.block.0.0.weight", "image_model.features.4.0.block.0.1.weight", "image_model.features.4.0.block.0.1.bias", "image_model.features.4.0.block.0.1.running_mean", "image_model.features.4.0.block.0.1.running_var", "image_model.features.4.0.block.1.0.weight", "image_model.features.4.0.block.1.1.weight", "image_model.features.4.0.block.1.1.bias", "image_model.features.4.0.block.1.1.running_mean", "image_model.features.4.0.block.1.1.running_var", "image_model.features.4.0.block.2.fc1.weight", "image_model.features.4.0.block.2.fc1.bias", "image_model.features.4.0.block.2.fc2.weight", "image_model.features.4.0.block.2.fc2.bias", "image_model.features.4.0.block.3.0.weight", "image_model.features.4.0.block.3.1.weight", "image_model.features.4.0.block.3.1.bias", "image_model.features.4.0.block.3.1.running_mean", "image_model.features.4.0.block.3.1.running_var", "image_model.features.4.1.block.0.0.weight", "image_model.features.4.1.block.0.1.weight", "image_model.features.4.1.block.0.1.bias", "image_model.features.4.1.block.0.1.running_mean", "image_model.features.4.1.block.0.1.running_var", "image_model.features.4.1.block.1.0.weight", "image_model.features.4.1.block.1.1.weight", "image_model.features.4.1.block.1.1.bias", "image_model.features.4.1.block.1.1.running_mean", "image_model.features.4.1.block.1.1.running_var", "image_model.features.4.1.block.2.fc1.weight", "image_model.features.4.1.block.2.fc1.bias", "image_model.features.4.1.block.2.fc2.weight", "image_model.features.4.1.block.2.fc2.bias", "image_model.features.4.1.block.3.0.weight", "image_model.features.4.1.block.3.1.weight", "image_model.features.4.1.block.3.1.bias", "image_model.features.4.1.block.3.1.running_mean", "image_model.features.4.1.block.3.1.running_var", "image_model.features.4.2.block.0.0.weight", "image_model.features.4.2.block.0.1.weight", "image_model.features.4.2.block.0.1.bias", "image_model.features.4.2.block.0.1.running_mean", "image_model.features.4.2.block.0.1.running_var", "image_model.features.4.2.block.1.0.weight", "image_model.features.4.2.block.1.1.weight", "image_model.features.4.2.block.1.1.bias", "image_model.features.4.2.block.1.1.running_mean", "image_model.features.4.2.block.1.1.running_var", "image_model.features.4.2.block.2.fc1.weight", "image_model.features.4.2.block.2.fc1.bias", "image_model.features.4.2.block.2.fc2.weight", "image_model.features.4.2.block.2.fc2.bias", "image_model.features.4.2.block.3.0.weight", "image_model.features.4.2.block.3.1.weight", "image_model.features.4.2.block.3.1.bias", "image_model.features.4.2.block.3.1.running_mean", "image_model.features.4.2.block.3.1.running_var", "image_model.features.4.3.block.0.0.weight", "image_model.features.4.3.block.0.1.weight", "image_model.features.4.3.block.0.1.bias", "image_model.features.4.3.block.0.1.running_mean", "image_model.features.4.3.block.0.1.running_var", "image_model.features.4.3.block.1.0.weight", "image_model.features.4.3.block.1.1.weight", "image_model.features.4.3.block.1.1.bias", "image_model.features.4.3.block.1.1.running_mean", "image_model.features.4.3.block.1.1.running_var", "image_model.features.4.3.block.2.fc1.weight", "image_model.features.4.3.block.2.fc1.bias", "image_model.features.4.3.block.2.fc2.weight", "image_model.features.4.3.block.2.fc2.bias", "image_model.features.4.3.block.3.0.weight", "image_model.features.4.3.block.3.1.weight", "image_model.features.4.3.block.3.1.bias", "image_model.features.4.3.block.3.1.running_mean", "image_model.features.4.3.block.3.1.running_var", "image_model.features.5.0.block.0.0.weight", "image_model.features.5.0.block.0.1.weight", "image_model.features.5.0.block.0.1.bias", "image_model.features.5.0.block.0.1.running_mean", "image_model.features.5.0.block.0.1.running_var", "image_model.features.5.0.block.1.0.weight", "image_model.features.5.0.block.1.1.weight", "image_model.features.5.0.block.1.1.bias", "image_model.features.5.0.block.1.1.running_mean", "image_model.features.5.0.block.1.1.running_var", "image_model.features.5.0.block.2.fc1.weight", "image_model.features.5.0.block.2.fc1.bias", "image_model.features.5.0.block.2.fc2.weight", "image_model.features.5.0.block.2.fc2.bias", "image_model.features.5.0.block.3.0.weight", "image_model.features.5.0.block.3.1.weight", "image_model.features.5.0.block.3.1.bias", "image_model.features.5.0.block.3.1.running_mean", "image_model.features.5.0.block.3.1.running_var", "image_model.features.5.1.block.0.0.weight", "image_model.features.5.1.block.0.1.weight", "image_model.features.5.1.block.0.1.bias", "image_model.features.5.1.block.0.1.running_mean", "image_model.features.5.1.block.0.1.running_var", "image_model.features.5.1.block.1.0.weight", "image_model.features.5.1.block.1.1.weight", "image_model.features.5.1.block.1.1.bias", "image_model.features.5.1.block.1.1.running_mean", "image_model.features.5.1.block.1.1.running_var", "image_model.features.5.1.block.2.fc1.weight", "image_model.features.5.1.block.2.fc1.bias", "image_model.features.5.1.block.2.fc2.weight", "image_model.features.5.1.block.2.fc2.bias", "image_model.features.5.1.block.3.0.weight", "image_model.features.5.1.block.3.1.weight", "image_model.features.5.1.block.3.1.bias", "image_model.features.5.1.block.3.1.running_mean", "image_model.features.5.1.block.3.1.running_var", "image_model.features.5.2.block.0.0.weight", "image_model.features.5.2.block.0.1.weight", "image_model.features.5.2.block.0.1.bias", "image_model.features.5.2.block.0.1.running_mean", "image_model.features.5.2.block.0.1.running_var", "image_model.features.5.2.block.1.0.weight", "image_model.features.5.2.block.1.1.weight", "image_model.features.5.2.block.1.1.bias", "image_model.features.5.2.block.1.1.running_mean", "image_model.features.5.2.block.1.1.running_var", "image_model.features.5.2.block.2.fc1.weight", "image_model.features.5.2.block.2.fc1.bias", "image_model.features.5.2.block.2.fc2.weight", "image_model.features.5.2.block.2.fc2.bias", "image_model.features.5.2.block.3.0.weight", "image_model.features.5.2.block.3.1.weight", "image_model.features.5.2.block.3.1.bias", "image_model.features.5.2.block.3.1.running_mean", "image_model.features.5.2.block.3.1.running_var", "image_model.features.5.3.block.0.0.weight", "image_model.features.5.3.block.0.1.weight", "image_model.features.5.3.block.0.1.bias", "image_model.features.5.3.block.0.1.running_mean", "image_model.features.5.3.block.0.1.running_var", "image_model.features.5.3.block.1.0.weight", "image_model.features.5.3.block.1.1.weight", "image_model.features.5.3.block.1.1.bias", "image_model.features.5.3.block.1.1.running_mean", "image_model.features.5.3.block.1.1.running_var", "image_model.features.5.3.block.2.fc1.weight", "image_model.features.5.3.block.2.fc1.bias", "image_model.features.5.3.block.2.fc2.weight", "image_model.features.5.3.block.2.fc2.bias", "image_model.features.5.3.block.3.0.weight", "image_model.features.5.3.block.3.1.weight", "image_model.features.5.3.block.3.1.bias", "image_model.features.5.3.block.3.1.running_mean", "image_model.features.5.3.block.3.1.running_var", "image_model.features.6.0.block.0.0.weight", "image_model.features.6.0.block.0.1.weight", "image_model.features.6.0.block.0.1.bias", "image_model.features.6.0.block.0.1.running_mean", "image_model.features.6.0.block.0.1.running_var", "image_model.features.6.0.block.1.0.weight", "image_model.features.6.0.block.1.1.weight", "image_model.features.6.0.block.1.1.bias", "image_model.features.6.0.block.1.1.running_mean", "image_model.features.6.0.block.1.1.running_var", "image_model.features.6.0.block.2.fc1.weight", "image_model.features.6.0.block.2.fc1.bias", "image_model.features.6.0.block.2.fc2.weight", "image_model.features.6.0.block.2.fc2.bias", "image_model.features.6.0.block.3.0.weight", "image_model.features.6.0.block.3.1.weight", "image_model.features.6.0.block.3.1.bias", "image_model.features.6.0.block.3.1.running_mean", "image_model.features.6.0.block.3.1.running_var", "image_model.features.6.1.block.0.0.weight", "image_model.features.6.1.block.0.1.weight", "image_model.features.6.1.block.0.1.bias", "image_model.features.6.1.block.0.1.running_mean", "image_model.features.6.1.block.0.1.running_var", "image_model.features.6.1.block.1.0.weight", "image_model.features.6.1.block.1.1.weight", "image_model.features.6.1.block.1.1.bias", "image_model.features.6.1.block.1.1.running_mean", "image_model.features.6.1.block.1.1.running_var", "image_model.features.6.1.block.2.fc1.weight", "image_model.features.6.1.block.2.fc1.bias", "image_model.features.6.1.block.2.fc2.weight", "image_model.features.6.1.block.2.fc2.bias", "image_model.features.6.1.block.3.0.weight", "image_model.features.6.1.block.3.1.weight", "image_model.features.6.1.block.3.1.bias", "image_model.features.6.1.block.3.1.running_mean", "image_model.features.6.1.block.3.1.running_var", "image_model.features.6.2.block.0.0.weight", "image_model.features.6.2.block.0.1.weight", "image_model.features.6.2.block.0.1.bias", "image_model.features.6.2.block.0.1.running_mean", "image_model.features.6.2.block.0.1.running_var", "image_model.features.6.2.block.1.0.weight", "image_model.features.6.2.block.1.1.weight", "image_model.features.6.2.block.1.1.bias", "image_model.features.6.2.block.1.1.running_mean", "image_model.features.6.2.block.1.1.running_var", "image_model.features.6.2.block.2.fc1.weight", "image_model.features.6.2.block.2.fc1.bias", "image_model.features.6.2.block.2.fc2.weight", "image_model.features.6.2.block.2.fc2.bias", "image_model.features.6.2.block.3.0.weight", "image_model.features.6.2.block.3.1.weight", "image_model.features.6.2.block.3.1.bias", "image_model.features.6.2.block.3.1.running_mean", "image_model.features.6.2.block.3.1.running_var", "image_model.features.6.3.block.0.0.weight", "image_model.features.6.3.block.0.1.weight", "image_model.features.6.3.block.0.1.bias", "image_model.features.6.3.block.0.1.running_mean", "image_model.features.6.3.block.0.1.running_var", "image_model.features.6.3.block.1.0.weight", "image_model.features.6.3.block.1.1.weight", "image_model.features.6.3.block.1.1.bias", "image_model.features.6.3.block.1.1.running_mean", "image_model.features.6.3.block.1.1.running_var", "image_model.features.6.3.block.2.fc1.weight", "image_model.features.6.3.block.2.fc1.bias", "image_model.features.6.3.block.2.fc2.weight", "image_model.features.6.3.block.2.fc2.bias", "image_model.features.6.3.block.3.0.weight", "image_model.features.6.3.block.3.1.weight", "image_model.features.6.3.block.3.1.bias", "image_model.features.6.3.block.3.1.running_mean", "image_model.features.6.3.block.3.1.running_var", "image_model.features.6.4.block.0.0.weight", "image_model.features.6.4.block.0.1.weight", "image_model.features.6.4.block.0.1.bias", "image_model.features.6.4.block.0.1.running_mean", "image_model.features.6.4.block.0.1.running_var", "image_model.features.6.4.block.1.0.weight", "image_model.features.6.4.block.1.1.weight", "image_model.features.6.4.block.1.1.bias", "image_model.features.6.4.block.1.1.running_mean", "image_model.features.6.4.block.1.1.running_var", "image_model.features.6.4.block.2.fc1.weight", "image_model.features.6.4.block.2.fc1.bias", "image_model.features.6.4.block.2.fc2.weight", "image_model.features.6.4.block.2.fc2.bias", "image_model.features.6.4.block.3.0.weight", "image_model.features.6.4.block.3.1.weight", "image_model.features.6.4.block.3.1.bias", "image_model.features.6.4.block.3.1.running_mean", "image_model.features.6.4.block.3.1.running_var", "image_model.features.7.0.block.0.0.weight", "image_model.features.7.0.block.0.1.weight", "image_model.features.7.0.block.0.1.bias", "image_model.features.7.0.block.0.1.running_mean", "image_model.features.7.0.block.0.1.running_var", "image_model.features.7.0.block.1.0.weight", "image_model.features.7.0.block.1.1.weight", "image_model.features.7.0.block.1.1.bias", "image_model.features.7.0.block.1.1.running_mean", "image_model.features.7.0.block.1.1.running_var", "image_model.features.7.0.block.2.fc1.weight", "image_model.features.7.0.block.2.fc1.bias", "image_model.features.7.0.block.2.fc2.weight", "image_model.features.7.0.block.2.fc2.bias", "image_model.features.7.0.block.3.0.weight", "image_model.features.7.0.block.3.1.weight", "image_model.features.7.0.block.3.1.bias", "image_model.features.7.0.block.3.1.running_mean", "image_model.features.7.0.block.3.1.running_var", "image_model.features.7.1.block.0.0.weight", "image_model.features.7.1.block.0.1.weight", "image_model.features.7.1.block.0.1.bias", "image_model.features.7.1.block.0.1.running_mean", "image_model.features.7.1.block.0.1.running_var", "image_model.features.7.1.block.1.0.weight", "image_model.features.7.1.block.1.1.weight", "image_model.features.7.1.block.1.1.bias", "image_model.features.7.1.block.1.1.running_mean", "image_model.features.7.1.block.1.1.running_var", "image_model.features.7.1.block.2.fc1.weight", "image_model.features.7.1.block.2.fc1.bias", "image_model.features.7.1.block.2.fc2.weight", "image_model.features.7.1.block.2.fc2.bias", "image_model.features.7.1.block.3.0.weight", "image_model.features.7.1.block.3.1.weight", "image_model.features.7.1.block.3.1.bias", "image_model.features.7.1.block.3.1.running_mean", "image_model.features.7.1.block.3.1.running_var", "image_model.features.8.0.weight", "image_model.features.8.1.weight", "image_model.features.8.1.bias", "image_model.features.8.1.running_mean", "image_model.features.8.1.running_var", "meta_net.0.weight", "meta_net.0.bias", "meta_net.1.weight", "meta_net.1.bias", "meta_net.1.running_mean", "meta_net.1.running_var", "meta_net.4.weight", "meta_net.4.bias", "meta_net.5.weight", "meta_net.5.bias", "meta_net.5.running_mean", "meta_net.5.running_var", "classifier.0.weight", "classifier.0.bias", "classifier.1.weight", "classifier.1.bias", "classifier.1.running_mean", "classifier.1.running_var", "classifier.4.weight", "classifier.4.bias". 
	Unexpected key(s) in state_dict: "features.0.0.weight", "features.0.1.weight", "features.0.1.bias", "features.0.1.running_mean", "features.0.1.running_var", "features.0.1.num_batches_tracked", "features.1.0.block.0.0.weight", "features.1.0.block.0.1.weight", "features.1.0.block.0.1.bias", "features.1.0.block.0.1.running_mean", "features.1.0.block.0.1.running_var", "features.1.0.block.0.1.num_batches_tracked", "features.1.0.block.1.fc1.weight", "features.1.0.block.1.fc1.bias", "features.1.0.block.1.fc2.weight", "features.1.0.block.1.fc2.bias", "features.1.0.block.2.0.weight", "features.1.0.block.2.1.weight", "features.1.0.block.2.1.bias", "features.1.0.block.2.1.running_mean", "features.1.0.block.2.1.running_var", "features.1.0.block.2.1.num_batches_tracked", "features.1.1.block.0.0.weight", "features.1.1.block.0.1.weight", "features.1.1.block.0.1.bias", "features.1.1.block.0.1.running_mean", "features.1.1.block.0.1.running_var", "features.1.1.block.0.1.num_batches_tracked", "features.1.1.block.1.fc1.weight", "features.1.1.block.1.fc1.bias", "features.1.1.block.1.fc2.weight", "features.1.1.block.1.fc2.bias", "features.1.1.block.2.0.weight", "features.1.1.block.2.1.weight", "features.1.1.block.2.1.bias", "features.1.1.block.2.1.running_mean", "features.1.1.block.2.1.running_var", "features.1.1.block.2.1.num_batches_tracked", "features.2.0.block.0.0.weight", "features.2.0.block.0.1.weight", "features.2.0.block.0.1.bias", "features.2.0.block.0.1.running_mean", "features.2.0.block.0.1.running_var", "features.2.0.block.0.1.num_batches_tracked", "features.2.0.block.1.0.weight", "features.2.0.block.1.1.weight", "features.2.0.block.1.1.bias", "features.2.0.block.1.1.running_mean", "features.2.0.block.1.1.running_var", "features.2.0.block.1.1.num_batches_tracked", "features.2.0.block.2.fc1.weight", "features.2.0.block.2.fc1.bias", "features.2.0.block.2.fc2.weight", "features.2.0.block.2.fc2.bias", "features.2.0.block.3.0.weight", "features.2.0.block.3.1.weight", "features.2.0.block.3.1.bias", "features.2.0.block.3.1.running_mean", "features.2.0.block.3.1.running_var", "features.2.0.block.3.1.num_batches_tracked", "features.2.1.block.0.0.weight", "features.2.1.block.0.1.weight", "features.2.1.block.0.1.bias", "features.2.1.block.0.1.running_mean", "features.2.1.block.0.1.running_var", "features.2.1.block.0.1.num_batches_tracked", "features.2.1.block.1.0.weight", "features.2.1.block.1.1.weight", "features.2.1.block.1.1.bias", "features.2.1.block.1.1.running_mean", "features.2.1.block.1.1.running_var", "features.2.1.block.1.1.num_batches_tracked", "features.2.1.block.2.fc1.weight", "features.2.1.block.2.fc1.bias", "features.2.1.block.2.fc2.weight", "features.2.1.block.2.fc2.bias", "features.2.1.block.3.0.weight", "features.2.1.block.3.1.weight", "features.2.1.block.3.1.bias", "features.2.1.block.3.1.running_mean", "features.2.1.block.3.1.running_var", "features.2.1.block.3.1.num_batches_tracked", "features.2.2.block.0.0.weight", "features.2.2.block.0.1.weight", "features.2.2.block.0.1.bias", "features.2.2.block.0.1.running_mean", "features.2.2.block.0.1.running_var", "features.2.2.block.0.1.num_batches_tracked", "features.2.2.block.1.0.weight", "features.2.2.block.1.1.weight", "features.2.2.block.1.1.bias", "features.2.2.block.1.1.running_mean", "features.2.2.block.1.1.running_var", "features.2.2.block.1.1.num_batches_tracked", "features.2.2.block.2.fc1.weight", "features.2.2.block.2.fc1.bias", "features.2.2.block.2.fc2.weight", "features.2.2.block.2.fc2.bias", "features.2.2.block.3.0.weight", "features.2.2.block.3.1.weight", "features.2.2.block.3.1.bias", "features.2.2.block.3.1.running_mean", "features.2.2.block.3.1.running_var", "features.2.2.block.3.1.num_batches_tracked", "features.3.0.block.0.0.weight", "features.3.0.block.0.1.weight", "features.3.0.block.0.1.bias", "features.3.0.block.0.1.running_mean", "features.3.0.block.0.1.running_var", "features.3.0.block.0.1.num_batches_tracked", "features.3.0.block.1.0.weight", "features.3.0.block.1.1.weight", "features.3.0.block.1.1.bias", "features.3.0.block.1.1.running_mean", "features.3.0.block.1.1.running_var", "features.3.0.block.1.1.num_batches_tracked", "features.3.0.block.2.fc1.weight", "features.3.0.block.2.fc1.bias", "features.3.0.block.2.fc2.weight", "features.3.0.block.2.fc2.bias", "features.3.0.block.3.0.weight", "features.3.0.block.3.1.weight", "features.3.0.block.3.1.bias", "features.3.0.block.3.1.running_mean", "features.3.0.block.3.1.running_var", "features.3.0.block.3.1.num_batches_tracked", "features.3.1.block.0.0.weight", "features.3.1.block.0.1.weight", "features.3.1.block.0.1.bias", "features.3.1.block.0.1.running_mean", "features.3.1.block.0.1.running_var", "features.3.1.block.0.1.num_batches_tracked", "features.3.1.block.1.0.weight", "features.3.1.block.1.1.weight", "features.3.1.block.1.1.bias", "features.3.1.block.1.1.running_mean", "features.3.1.block.1.1.running_var", "features.3.1.block.1.1.num_batches_tracked", "features.3.1.block.2.fc1.weight", "features.3.1.block.2.fc1.bias", "features.3.1.block.2.fc2.weight", "features.3.1.block.2.fc2.bias", "features.3.1.block.3.0.weight", "features.3.1.block.3.1.weight", "features.3.1.block.3.1.bias", "features.3.1.block.3.1.running_mean", "features.3.1.block.3.1.running_var", "features.3.1.block.3.1.num_batches_tracked", "features.3.2.block.0.0.weight", "features.3.2.block.0.1.weight", "features.3.2.block.0.1.bias", "features.3.2.block.0.1.running_mean", "features.3.2.block.0.1.running_var", "features.3.2.block.0.1.num_batches_tracked", "features.3.2.block.1.0.weight", "features.3.2.block.1.1.weight", "features.3.2.block.1.1.bias", "features.3.2.block.1.1.running_mean", "features.3.2.block.1.1.running_var", "features.3.2.block.1.1.num_batches_tracked", "features.3.2.block.2.fc1.weight", "features.3.2.block.2.fc1.bias", "features.3.2.block.2.fc2.weight", "features.3.2.block.2.fc2.bias", "features.3.2.block.3.0.weight", "features.3.2.block.3.1.weight", "features.3.2.block.3.1.bias", "features.3.2.block.3.1.running_mean", "features.3.2.block.3.1.running_var", "features.3.2.block.3.1.num_batches_tracked", "features.4.0.block.0.0.weight", "features.4.0.block.0.1.weight", "features.4.0.block.0.1.bias", "features.4.0.block.0.1.running_mean", "features.4.0.block.0.1.running_var", "features.4.0.block.0.1.num_batches_tracked", "features.4.0.block.1.0.weight", "features.4.0.block.1.1.weight", "features.4.0.block.1.1.bias", "features.4.0.block.1.1.running_mean", "features.4.0.block.1.1.running_var", "features.4.0.block.1.1.num_batches_tracked", "features.4.0.block.2.fc1.weight", "features.4.0.block.2.fc1.bias", "features.4.0.block.2.fc2.weight", "features.4.0.block.2.fc2.bias", "features.4.0.block.3.0.weight", "features.4.0.block.3.1.weight", "features.4.0.block.3.1.bias", "features.4.0.block.3.1.running_mean", "features.4.0.block.3.1.running_var", "features.4.0.block.3.1.num_batches_tracked", "features.4.1.block.0.0.weight", "features.4.1.block.0.1.weight", "features.4.1.block.0.1.bias", "features.4.1.block.0.1.running_mean", "features.4.1.block.0.1.running_var", "features.4.1.block.0.1.num_batches_tracked", "features.4.1.block.1.0.weight", "features.4.1.block.1.1.weight", "features.4.1.block.1.1.bias", "features.4.1.block.1.1.running_mean", "features.4.1.block.1.1.running_var", "features.4.1.block.1.1.num_batches_tracked", "features.4.1.block.2.fc1.weight", "features.4.1.block.2.fc1.bias", "features.4.1.block.2.fc2.weight", "features.4.1.block.2.fc2.bias", "features.4.1.block.3.0.weight", "features.4.1.block.3.1.weight", "features.4.1.block.3.1.bias", "features.4.1.block.3.1.running_mean", "features.4.1.block.3.1.running_var", "features.4.1.block.3.1.num_batches_tracked", "features.4.2.block.0.0.weight", "features.4.2.block.0.1.weight", "features.4.2.block.0.1.bias", "features.4.2.block.0.1.running_mean", "features.4.2.block.0.1.running_var", "features.4.2.block.0.1.num_batches_tracked", "features.4.2.block.1.0.weight", "features.4.2.block.1.1.weight", "features.4.2.block.1.1.bias", "features.4.2.block.1.1.running_mean", "features.4.2.block.1.1.running_var", "features.4.2.block.1.1.num_batches_tracked", "features.4.2.block.2.fc1.weight", "features.4.2.block.2.fc1.bias", "features.4.2.block.2.fc2.weight", "features.4.2.block.2.fc2.bias", "features.4.2.block.3.0.weight", "features.4.2.block.3.1.weight", "features.4.2.block.3.1.bias", "features.4.2.block.3.1.running_mean", "features.4.2.block.3.1.running_var", "features.4.2.block.3.1.num_batches_tracked", "features.4.3.block.0.0.weight", "features.4.3.block.0.1.weight", "features.4.3.block.0.1.bias", "features.4.3.block.0.1.running_mean", "features.4.3.block.0.1.running_var", "features.4.3.block.0.1.num_batches_tracked", "features.4.3.block.1.0.weight", "features.4.3.block.1.1.weight", "features.4.3.block.1.1.bias", "features.4.3.block.1.1.running_mean", "features.4.3.block.1.1.running_var", "features.4.3.block.1.1.num_batches_tracked", "features.4.3.block.2.fc1.weight", "features.4.3.block.2.fc1.bias", "features.4.3.block.2.fc2.weight", "features.4.3.block.2.fc2.bias", "features.4.3.block.3.0.weight", "features.4.3.block.3.1.weight", "features.4.3.block.3.1.bias", "features.4.3.block.3.1.running_mean", "features.4.3.block.3.1.running_var", "features.4.3.block.3.1.num_batches_tracked", "features.5.0.block.0.0.weight", "features.5.0.block.0.1.weight", "features.5.0.block.0.1.bias", "features.5.0.block.0.1.running_mean", "features.5.0.block.0.1.running_var", "features.5.0.block.0.1.num_batches_tracked", "features.5.0.block.1.0.weight", "features.5.0.block.1.1.weight", "features.5.0.block.1.1.bias", "features.5.0.block.1.1.running_mean", "features.5.0.block.1.1.running_var", "features.5.0.block.1.1.num_batches_tracked", "features.5.0.block.2.fc1.weight", "features.5.0.block.2.fc1.bias", "features.5.0.block.2.fc2.weight", "features.5.0.block.2.fc2.bias", "features.5.0.block.3.0.weight", "features.5.0.block.3.1.weight", "features.5.0.block.3.1.bias", "features.5.0.block.3.1.running_mean", "features.5.0.block.3.1.running_var", "features.5.0.block.3.1.num_batches_tracked", "features.5.1.block.0.0.weight", "features.5.1.block.0.1.weight", "features.5.1.block.0.1.bias", "features.5.1.block.0.1.running_mean", "features.5.1.block.0.1.running_var", "features.5.1.block.0.1.num_batches_tracked", "features.5.1.block.1.0.weight", "features.5.1.block.1.1.weight", "features.5.1.block.1.1.bias", "features.5.1.block.1.1.running_mean", "features.5.1.block.1.1.running_var", "features.5.1.block.1.1.num_batches_tracked", "features.5.1.block.2.fc1.weight", "features.5.1.block.2.fc1.bias", "features.5.1.block.2.fc2.weight", "features.5.1.block.2.fc2.bias", "features.5.1.block.3.0.weight", "features.5.1.block.3.1.weight", "features.5.1.block.3.1.bias", "features.5.1.block.3.1.running_mean", "features.5.1.block.3.1.running_var", "features.5.1.block.3.1.num_batches_tracked", "features.5.2.block.0.0.weight", "features.5.2.block.0.1.weight", "features.5.2.block.0.1.bias", "features.5.2.block.0.1.running_mean", "features.5.2.block.0.1.running_var", "features.5.2.block.0.1.num_batches_tracked", "features.5.2.block.1.0.weight", "features.5.2.block.1.1.weight", "features.5.2.block.1.1.bias", "features.5.2.block.1.1.running_mean", "features.5.2.block.1.1.running_var", "features.5.2.block.1.1.num_batches_tracked", "features.5.2.block.2.fc1.weight", "features.5.2.block.2.fc1.bias", "features.5.2.block.2.fc2.weight", "features.5.2.block.2.fc2.bias", "features.5.2.block.3.0.weight", "features.5.2.block.3.1.weight", "features.5.2.block.3.1.bias", "features.5.2.block.3.1.running_mean", "features.5.2.block.3.1.running_var", "features.5.2.block.3.1.num_batches_tracked", "features.5.3.block.0.0.weight", "features.5.3.block.0.1.weight", "features.5.3.block.0.1.bias", "features.5.3.block.0.1.running_mean", "features.5.3.block.0.1.running_var", "features.5.3.block.0.1.num_batches_tracked", "features.5.3.block.1.0.weight", "features.5.3.block.1.1.weight", "features.5.3.block.1.1.bias", "features.5.3.block.1.1.running_mean", "features.5.3.block.1.1.running_var", "features.5.3.block.1.1.num_batches_tracked", "features.5.3.block.2.fc1.weight", "features.5.3.block.2.fc1.bias", "features.5.3.block.2.fc2.weight", "features.5.3.block.2.fc2.bias", "features.5.3.block.3.0.weight", "features.5.3.block.3.1.weight", "features.5.3.block.3.1.bias", "features.5.3.block.3.1.running_mean", "features.5.3.block.3.1.running_var", "features.5.3.block.3.1.num_batches_tracked", "features.6.0.block.0.0.weight", "features.6.0.block.0.1.weight", "features.6.0.block.0.1.bias", "features.6.0.block.0.1.running_mean", "features.6.0.block.0.1.running_var", "features.6.0.block.0.1.num_batches_tracked", "features.6.0.block.1.0.weight", "features.6.0.block.1.1.weight", "features.6.0.block.1.1.bias", "features.6.0.block.1.1.running_mean", "features.6.0.block.1.1.running_var", "features.6.0.block.1.1.num_batches_tracked", "features.6.0.block.2.fc1.weight", "features.6.0.block.2.fc1.bias", "features.6.0.block.2.fc2.weight", "features.6.0.block.2.fc2.bias", "features.6.0.block.3.0.weight", "features.6.0.block.3.1.weight", "features.6.0.block.3.1.bias", "features.6.0.block.3.1.running_mean", "features.6.0.block.3.1.running_var", "features.6.0.block.3.1.num_batches_tracked", "features.6.1.block.0.0.weight", "features.6.1.block.0.1.weight", "features.6.1.block.0.1.bias", "features.6.1.block.0.1.running_mean", "features.6.1.block.0.1.running_var", "features.6.1.block.0.1.num_batches_tracked", "features.6.1.block.1.0.weight", "features.6.1.block.1.1.weight", "features.6.1.block.1.1.bias", "features.6.1.block.1.1.running_mean", "features.6.1.block.1.1.running_var", "features.6.1.block.1.1.num_batches_tracked", "features.6.1.block.2.fc1.weight", "features.6.1.block.2.fc1.bias", "features.6.1.block.2.fc2.weight", "features.6.1.block.2.fc2.bias", "features.6.1.block.3.0.weight", "features.6.1.block.3.1.weight", "features.6.1.block.3.1.bias", "features.6.1.block.3.1.running_mean", "features.6.1.block.3.1.running_var", "features.6.1.block.3.1.num_batches_tracked", "features.6.2.block.0.0.weight", "features.6.2.block.0.1.weight", "features.6.2.block.0.1.bias", "features.6.2.block.0.1.running_mean", "features.6.2.block.0.1.running_var", "features.6.2.block.0.1.num_batches_tracked", "features.6.2.block.1.0.weight", "features.6.2.block.1.1.weight", "features.6.2.block.1.1.bias", "features.6.2.block.1.1.running_mean", "features.6.2.block.1.1.running_var", "features.6.2.block.1.1.num_batches_tracked", "features.6.2.block.2.fc1.weight", "features.6.2.block.2.fc1.bias", "features.6.2.block.2.fc2.weight", "features.6.2.block.2.fc2.bias", "features.6.2.block.3.0.weight", "features.6.2.block.3.1.weight", "features.6.2.block.3.1.bias", "features.6.2.block.3.1.running_mean", "features.6.2.block.3.1.running_var", "features.6.2.block.3.1.num_batches_tracked", "features.6.3.block.0.0.weight", "features.6.3.block.0.1.weight", "features.6.3.block.0.1.bias", "features.6.3.block.0.1.running_mean", "features.6.3.block.0.1.running_var", "features.6.3.block.0.1.num_batches_tracked", "features.6.3.block.1.0.weight", "features.6.3.block.1.1.weight", "features.6.3.block.1.1.bias", "features.6.3.block.1.1.running_mean", "features.6.3.block.1.1.running_var", "features.6.3.block.1.1.num_batches_tracked", "features.6.3.block.2.fc1.weight", "features.6.3.block.2.fc1.bias", "features.6.3.block.2.fc2.weight", "features.6.3.block.2.fc2.bias", "features.6.3.block.3.0.weight", "features.6.3.block.3.1.weight", "features.6.3.block.3.1.bias", "features.6.3.block.3.1.running_mean", "features.6.3.block.3.1.running_var", "features.6.3.block.3.1.num_batches_tracked", "features.6.4.block.0.0.weight", "features.6.4.block.0.1.weight", "features.6.4.block.0.1.bias", "features.6.4.block.0.1.running_mean", "features.6.4.block.0.1.running_var", "features.6.4.block.0.1.num_batches_tracked", "features.6.4.block.1.0.weight", "features.6.4.block.1.1.weight", "features.6.4.block.1.1.bias", "features.6.4.block.1.1.running_mean", "features.6.4.block.1.1.running_var", "features.6.4.block.1.1.num_batches_tracked", "features.6.4.block.2.fc1.weight", "features.6.4.block.2.fc1.bias", "features.6.4.block.2.fc2.weight", "features.6.4.block.2.fc2.bias", "features.6.4.block.3.0.weight", "features.6.4.block.3.1.weight", "features.6.4.block.3.1.bias", "features.6.4.block.3.1.running_mean", "features.6.4.block.3.1.running_var", "features.6.4.block.3.1.num_batches_tracked", "features.7.0.block.0.0.weight", "features.7.0.block.0.1.weight", "features.7.0.block.0.1.bias", "features.7.0.block.0.1.running_mean", "features.7.0.block.0.1.running_var", "features.7.0.block.0.1.num_batches_tracked", "features.7.0.block.1.0.weight", "features.7.0.block.1.1.weight", "features.7.0.block.1.1.bias", "features.7.0.block.1.1.running_mean", "features.7.0.block.1.1.running_var", "features.7.0.block.1.1.num_batches_tracked", "features.7.0.block.2.fc1.weight", "features.7.0.block.2.fc1.bias", "features.7.0.block.2.fc2.weight", "features.7.0.block.2.fc2.bias", "features.7.0.block.3.0.weight", "features.7.0.block.3.1.weight", "features.7.0.block.3.1.bias", "features.7.0.block.3.1.running_mean", "features.7.0.block.3.1.running_var", "features.7.0.block.3.1.num_batches_tracked", "features.7.1.block.0.0.weight", "features.7.1.block.0.1.weight", "features.7.1.block.0.1.bias", "features.7.1.block.0.1.running_mean", "features.7.1.block.0.1.running_var", "features.7.1.block.0.1.num_batches_tracked", "features.7.1.block.1.0.weight", "features.7.1.block.1.1.weight", "features.7.1.block.1.1.bias", "features.7.1.block.1.1.running_mean", "features.7.1.block.1.1.running_var", "features.7.1.block.1.1.num_batches_tracked", "features.7.1.block.2.fc1.weight", "features.7.1.block.2.fc1.bias", "features.7.1.block.2.fc2.weight", "features.7.1.block.2.fc2.bias", "features.7.1.block.3.0.weight", "features.7.1.block.3.1.weight", "features.7.1.block.3.1.bias", "features.7.1.block.3.1.running_mean", "features.7.1.block.3.1.running_var", "features.7.1.block.3.1.num_batches_tracked", "features.8.0.weight", "features.8.1.weight", "features.8.1.bias", "features.8.1.running_mean", "features.8.1.running_var", "features.8.1.num_batches_tracked". 

In [None]:
# Load the checkpoint
checkpoint = torch.load('best_model.pth')

# Initialize model
model_save = FusionModel(meta_input_dim=19, num_classes=7)

# Load only the model weights
model_save.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [24]:
def train_fusion_with_checkpoint(
    checkpoint_path='best_cnn_model.pth',
    meta_input_dim=19,
    num_classes=7,
    train_loader=None,
    val_loader=None,
    fusion_epochs=20,
    device='cuda'
):
    """
    Train fusion model using pre-trained CNN weights
    
    Args:
        checkpoint_path: Path to saved CNN checkpoint
        meta_input_dim: Dimension of metadata features
        num_classes: Number of output classes
        train_loader: Training data loader
        val_loader: Validation data loader
        fusion_epochs: Number of epochs for fusion training
        device: Device to use ('cuda' or 'cpu')
    """
    # 1. Load the checkpoint
    checkpoint = torch.load(checkpoint_path)
    print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
    print(f"Previous val accuracy: {checkpoint.get('val_acc', 0):.4f}")
    
    # 2. Initialize model
    model = FusionModel(meta_input_dim=meta_input_dim, num_classes=num_classes)
    model.to(device)
    
    # 3. Load CNN weights
    model.image_model.load_state_dict({
        k.replace('image_model.', ''): v 
        for k, v in checkpoint['model_state_dict'].items()
        if k.startswith('image_model.')
    })
    
    # 4. Freeze CNN and prepare for fusion training
    model.freeze_cnn()
    model.unfreeze_classifier()
    
    # 5. Set up optimizer (only for unfrozen layers)
    optimizer = torch.optim.Adam([
        {'params': model.meta_net.parameters()},
        {'params': model.classifier.parameters()}
    ], lr=1e-3, weight_decay=1e-4)
    
    criterion = nn.CrossEntropyLoss()
    best_val_acc = 0.0
    
    # 6. Fusion training loop
    for epoch in range(fusion_epochs):
        model.train()
        train_loss = 0.0
        all_preds = []
        all_targets = []
        
        for images, metadata, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{fusion_epochs}"):
            images = images.to(device)
            metadata = metadata.to(device).float()
            targets = targets.to(device)
            
            optimizer.zero_grad()
            outputs = model(images, metadata)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
        
        # Calculate metrics
        train_loss /= len(train_loader)
        train_acc = accuracy_score(all_targets, all_preds)
        train_f1 = f1_score(all_targets, all_preds, average='weighted')
        
        # Validation
        val_loss, val_acc, val_f1 = validate(model, val_loader, criterion, device)
        
        print(f"\nEpoch {epoch+1}/{fusion_epochs}")
        print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | F1: {train_f1:.4f}")
        print(f"Val Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f}")
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'val_acc': val_acc,
            }, 'best_fusion_model.pth')
            print("Saved new best fusion model")
    
    print("\nFusion training complete!")
    print(f"Best validation accuracy: {best_val_acc:.4f}")
    
    return model

# Helper validation function
def validate(model, val_loader, criterion, device):
    model.eval()
    val_loss = 0.0
    all_preds = []
    all_targets = []
    
    with torch.no_grad():
        for images, metadata, targets in val_loader:
            images = images.to(device)
            metadata = metadata.to(device).float()
            targets = targets.to(device)
            
            outputs = model(images, metadata)
            loss = criterion(outputs, targets)
            
            val_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
    
    val_loss /= len(val_loader)
    val_acc = accuracy_score(all_targets, all_preds)
    val_f1 = f1_score(all_targets, all_preds, average='weighted')
    
    return val_loss, val_acc, val_f1

In [25]:
fusion_model = train_fusion_with_checkpoint(
    checkpoint_path='best_model.pth',
    train_loader=train_loader,
    val_loader=val_loader,
    fusion_epochs=20
)

Loaded checkpoint from epoch 9
Previous val accuracy: 0.8500


Epoch 1/20: 100%|██████████| 109/109 [00:43<00:00,  2.50it/s]



Epoch 1/20
Train Loss: 0.2141 | Acc: 0.9612 | F1: 0.9612
Val Loss: 0.5704 | Acc: 0.8419 | F1: 0.8428
Saved new best fusion model


Epoch 2/20: 100%|██████████| 109/109 [00:43<00:00,  2.49it/s]



Epoch 2/20
Train Loss: 0.0942 | Acc: 0.9743 | F1: 0.9742
Val Loss: 0.6112 | Acc: 0.8486 | F1: 0.8493
Saved new best fusion model


Epoch 3/20: 100%|██████████| 109/109 [00:43<00:00,  2.52it/s]



Epoch 3/20
Train Loss: 0.0746 | Acc: 0.9769 | F1: 0.9768
Val Loss: 0.6722 | Acc: 0.8426 | F1: 0.8415


Epoch 4/20: 100%|██████████| 109/109 [00:43<00:00,  2.53it/s]



Epoch 4/20
Train Loss: 0.0756 | Acc: 0.9770 | F1: 0.9770
Val Loss: 0.6904 | Acc: 0.8332 | F1: 0.8371


Epoch 5/20: 100%|██████████| 109/109 [00:43<00:00,  2.52it/s]



Epoch 5/20
Train Loss: 0.0784 | Acc: 0.9760 | F1: 0.9760
Val Loss: 0.6914 | Acc: 0.8399 | F1: 0.8407


Epoch 6/20: 100%|██████████| 109/109 [00:43<00:00,  2.51it/s]



Epoch 6/20
Train Loss: 0.0698 | Acc: 0.9782 | F1: 0.9781
Val Loss: 0.6645 | Acc: 0.8439 | F1: 0.8433


Epoch 7/20: 100%|██████████| 109/109 [00:44<00:00,  2.47it/s]



Epoch 7/20
Train Loss: 0.0686 | Acc: 0.9769 | F1: 0.9769
Val Loss: 0.7350 | Acc: 0.8346 | F1: 0.8343


Epoch 8/20: 100%|██████████| 109/109 [00:43<00:00,  2.49it/s]



Epoch 8/20
Train Loss: 0.0636 | Acc: 0.9794 | F1: 0.9794
Val Loss: 0.7336 | Acc: 0.8392 | F1: 0.8411


Epoch 9/20: 100%|██████████| 109/109 [00:44<00:00,  2.47it/s]



Epoch 9/20
Train Loss: 0.0686 | Acc: 0.9753 | F1: 0.9752
Val Loss: 0.7076 | Acc: 0.8433 | F1: 0.8433


Epoch 10/20: 100%|██████████| 109/109 [00:43<00:00,  2.49it/s]



Epoch 10/20
Train Loss: 0.0608 | Acc: 0.9796 | F1: 0.9796
Val Loss: 0.7255 | Acc: 0.8406 | F1: 0.8390


Epoch 11/20: 100%|██████████| 109/109 [00:43<00:00,  2.49it/s]



Epoch 11/20
Train Loss: 0.0668 | Acc: 0.9776 | F1: 0.9776
Val Loss: 0.7355 | Acc: 0.8446 | F1: 0.8462


Epoch 12/20: 100%|██████████| 109/109 [00:44<00:00,  2.46it/s]



Epoch 12/20
Train Loss: 0.0610 | Acc: 0.9796 | F1: 0.9796
Val Loss: 0.7619 | Acc: 0.8426 | F1: 0.8408


Epoch 13/20: 100%|██████████| 109/109 [00:44<00:00,  2.47it/s]



Epoch 13/20
Train Loss: 0.0653 | Acc: 0.9794 | F1: 0.9794
Val Loss: 0.7455 | Acc: 0.8439 | F1: 0.8443


Epoch 14/20: 100%|██████████| 109/109 [00:43<00:00,  2.53it/s]



Epoch 14/20
Train Loss: 0.0690 | Acc: 0.9786 | F1: 0.9786
Val Loss: 0.7638 | Acc: 0.8359 | F1: 0.8368


Epoch 15/20: 100%|██████████| 109/109 [00:43<00:00,  2.52it/s]



Epoch 15/20
Train Loss: 0.0524 | Acc: 0.9826 | F1: 0.9826
Val Loss: 0.7560 | Acc: 0.8386 | F1: 0.8395


Epoch 16/20: 100%|██████████| 109/109 [00:43<00:00,  2.52it/s]



Epoch 16/20
Train Loss: 0.0530 | Acc: 0.9819 | F1: 0.9819
Val Loss: 0.7575 | Acc: 0.8419 | F1: 0.8435


Epoch 17/20: 100%|██████████| 109/109 [00:43<00:00,  2.50it/s]



Epoch 17/20
Train Loss: 0.0497 | Acc: 0.9832 | F1: 0.9832
Val Loss: 0.7801 | Acc: 0.8453 | F1: 0.8448


Epoch 18/20: 100%|██████████| 109/109 [00:42<00:00,  2.55it/s]



Epoch 18/20
Train Loss: 0.0680 | Acc: 0.9790 | F1: 0.9790
Val Loss: 0.7731 | Acc: 0.8426 | F1: 0.8416


Epoch 19/20: 100%|██████████| 109/109 [00:44<00:00,  2.46it/s]



Epoch 19/20
Train Loss: 0.0766 | Acc: 0.9743 | F1: 0.9742
Val Loss: 0.7711 | Acc: 0.8366 | F1: 0.8394


Epoch 20/20: 100%|██████████| 109/109 [00:44<00:00,  2.47it/s]



Epoch 20/20
Train Loss: 0.0575 | Acc: 0.9825 | F1: 0.9825
Val Loss: 0.8222 | Acc: 0.8386 | F1: 0.8409

Fusion training complete!
Best validation accuracy: 0.8486


In [26]:
# Load the checkpoint
checkpoint2 = torch.load('best_fusion_model.pth')

# Initialize model
model_save2 = FusionModel(meta_input_dim=19, num_classes=7)

# Load only the model weights
model_save2.load_state_dict(checkpoint2['model_state_dict'])
results = test_model(model_save2, test_loader, class_names)

Testing: 100%|██████████| 24/24 [00:08<00:00,  2.88it/s]



=== PER-CLASS METRICS ===

Class: nv
Accuracy: 0.6648
Precision: 0.6223
Recall: 0.6648
F1: 0.6429
AUC: 0.9327

Class: mel
Accuracy: 0.9070
Precision: 0.9135
Recall: 0.9070
F1: 0.9102
AUC: 0.9403

Class: bkl
Accuracy: 0.5556
Precision: 0.4545
Recall: 0.5556
F1: 0.5000
AUC: 0.9632

Class: bcc
Accuracy: 0.5731
Precision: 0.5833
Recall: 0.5731
F1: 0.5782
AUC: 0.8956

Class: akiec
Accuracy: 0.7692
Precision: 0.8333
Recall: 0.7692
F1: 0.8000
AUC: 0.9926

Class: vasc
Accuracy: 0.6842
Precision: 0.7831
Recall: 0.6842
F1: 0.7303
AUC: 0.9808

Class: df
Accuracy: 0.6038
Precision: 0.5079
Recall: 0.6038
F1: 0.5517
AUC: 0.9274

=== OVERALL METRICS ===
Accuracy: 0.8117

Macro Averages:
Precision: 0.6712
Recall: 0.6797
F1: 0.6733
AUC: 0.9475

Weighted Averages:
Precision: 0.8155
Recall: 0.8117
F1: 0.8131
AUC: 0.9375
