In [20]:
import pandas as pd
import os

def get_image_path(image_id: int, tiles_dir: str = "tiles") -> str:
    """Return the directory holding the tile images of one slide.

    Args:
        image_id: Slide identifier (the ``image_id`` column of the metadata CSV).
        tiles_dir: Root directory that contains one sub-directory per slide.
            Defaults to ``"tiles"`` relative to the current working directory.

    Returns:
        ``os.path.join(tiles_dir, str(image_id))``.
    """
    return os.path.join(tiles_dir, str(image_id))

# Slide-level metadata: one row per slide with its label and dimensions.
# NOTE(review): despite the name `val`, the rows come from "train-no-tma.csv"
# (judging by the file name, training slides with TMAs excluded) — confirm intent.
val = pd.read_csv("train-no-tma.csv")

# Directory of extracted tiles for each slide; pass the function directly
# instead of wrapping it in an identical lambda.
val['tile_path'] = val['image_id'].apply(get_image_path)

val.head()

Unnamed: 0,image_id,label,image_width,image_height,is_tma,tile_path
0,38366,LGSC,31951,21718,False,tiles/38366
1,63298,HGSC,26067,20341,False,tiles/63298
2,54928,CC,36166,31487,False,tiles/54928
3,18813,CC,54671,32443,False,tiles/18813
4,63429,EC,67783,29066,False,tiles/63429


In [21]:
import os

def count_files(directory):
    """Count every file under ``directory``, recursing into sub-directories.

    Args:
        directory: Path of the directory to scan.

    Returns:
        The total number of files (all extensions, all nesting levels) as an int,
        or a human-readable message string when ``directory`` does not exist or
        is not a directory.
    """
    if os.path.exists(directory):
        if os.path.isdir(directory):
            # os.walk yields (dirpath, dirnames, filenames) for every level.
            return sum(len(names) for _root, _subdirs, names in os.walk(directory))
        return "The specified path is not a directory"
    return "The specified directory does not exist"

In [25]:
import timm
import torch
from torch import nn

device = "cpu"

# EfficientNetV2 variant to build; 's', 'm' and 'l' sizes exist. The network is built
# without pretrained weights because the trained weights come from CHECKPOINT_PATH.
model_name = 'efficientnetv2_s'
NUM_CLASSES = 5
CHECKPOINT_PATH = 'effnetv2-scratch-non-tma-models/epoch_0_batch_2000.pth'

model = timm.create_model(model_name, pretrained=False)

# This only records the intended input size in the model's config metadata (used by
# timm's data-config helpers); it does not resize the network. The real tile size
# is whatever the preprocessing transform produces.
model.default_cfg['input_size'] = (3, 224, 224)

# Replace the classifier head with a NUM_CLASSES-way linear layer. Depending on the
# architecture the last linear layer is exposed as `classifier` or `head.fc`.
if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Linear):
    model.classifier = nn.Linear(model.classifier.in_features, NUM_CLASSES)
elif hasattr(model, 'head') and hasattr(model.head, 'fc'):
    model.head.fc = nn.Linear(model.head.fc.in_features, NUM_CLASSES)
else:
    # Fail here with a clear message rather than continuing to a confusing
    # size-mismatch error from load_state_dict below.
    raise RuntimeError(
        f"{model_name} has no `classifier` (nn.Linear) or `head.fc` layer to replace"
    )

model = model.to(device)

# NOTE(review): torch.load unpickles arbitrary objects. For a plain state_dict prefer
# torch.load(..., weights_only=True) when the installed torch supports it (>= 1.13).
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(state_dict)

# Assigning (instead of a bare `model.eval()`) keeps the long module repr out of the
# cell output without needing a dummy print().
model = model.eval()




In [26]:
import torchvision.transforms as transforms

# Class index <-> subtype label. The index order is presumably the label encoding
# used when the checkpoint was trained (logit i <-> integer_to_label[i]) — confirm.
integer_to_label = {
    0: 'HGSC',
    1: 'CC',
    2: 'EC',
    3: 'LGSC',
    4: 'MC',
}

# Derived from integer_to_label so the two mappings can never drift apart.
label_to_integer = {label: index for index, label in integer_to_label.items()}

# Inference preprocessing: resize every tile to 224x224, convert to a float tensor,
# then map each channel through (x - 0.5) / 0.5.
# NOTE(review): these normalization statistics must match the ones used when the
# checkpoint was trained — confirm against the training code.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

In [None]:
import os
from PIL import Image, ImageFile
import torch
import pandas as pd
import torchvision.transforms as transforms
import random
import math
import numpy as np
from scipy.stats import entropy

# NOTE(review): this rebinds `transform` to a bare ToTensor, with no Resize and no
# Normalize. It replaces the normalized transform defined in an earlier cell for
# everything below, so tiles are fed at native size with unnormalized values.
# Confirm this matches the preprocessing used when the checkpoint was trained.
transform = transforms.Compose([transforms.ToTensor()])

def process_sub_images(path, sample_size=100, seed=None, aggregate="probability"):
    """Predict a slide-level label from a random sample of its tile images.

    Up to ``sample_size`` ``.png`` files are drawn at random from the top level of
    ``path``. Each tile goes through the notebook-level ``transform`` and ``model``
    (on ``device``). Per-tile softmax probabilities and argmax votes are accumulated
    and printed, then reduced to a single label.

    Args:
        path: Directory holding the slide's tile images.
        sample_size: Maximum number of tiles to classify (default 100).
        seed: Optional seed for tile sampling. ``None`` (default) draws from the
            global ``random`` module state.
        aggregate: ``"probability"`` (default) picks the class with the largest
            summed softmax probability. ``"vote"`` picks the class predicted for
            the most tiles, with ties going to the lowest class index.

    Returns:
        The predicted label from ``integer_to_label``, or ``None`` when ``path``
        contains no ``.png`` tiles.

    Raises:
        ValueError: If ``aggregate`` is not ``"probability"`` or ``"vote"``.
    """
    if aggregate not in ("probability", "vote"):
        raise ValueError(f"unknown aggregate {aggregate!r}; expected 'probability' or 'vote'")

    print('combing through this many files:', count_files(path))
    num_classes = len(integer_to_label)
    predicted_index_counts = {index: 0 for index in range(num_classes)}
    probabilities = np.zeros(num_classes)

    # Sort before sampling so a fixed seed selects the same tiles regardless of the
    # (arbitrary) order in which os.listdir returns entries.
    all_files = sorted(f for f in os.listdir(path) if f.lower().endswith('.png'))
    if not all_files:
        # Nothing to aggregate: an all-zero vector would otherwise "predict" index 0.
        print('no .png tiles found in', path)
        return None

    sampler = random if seed is None else random.Random(seed)
    sampled_files = sampler.sample(all_files, min(sample_size, len(all_files)))

    # Inference only: no_grad avoids building an autograd graph for every tile.
    with torch.no_grad():
        for image_name in sampled_files:
            # The context manager closes the file handle. convert("RGB") guarantees
            # 3 channels even for palette, RGBA or greyscale PNGs.
            with Image.open(os.path.join(path, image_name)) as sub_image:
                inputs = transform(sub_image.convert("RGB"))
            inputs = inputs.unsqueeze(0).to(device)

            outputs = model(inputs)
            predicted_index = outputs.argmax(dim=1).item()
            predicted_index_counts[predicted_index] += 1
            probabilities += outputs.softmax(dim=1).cpu().numpy()[0]

    index_counts = np.array(list(predicted_index_counts.values()))
    print(predicted_index_counts)
    print(probabilities)
    # Normalized vote / probability distributions and their entropies (higher means
    # the sampled tiles disagree more); scipy's entropy normalizes its input itself.
    print(index_counts / index_counts.sum(), entropy(index_counts + 1e-8))
    print(probabilities / probabilities.sum(), entropy(probabilities))

    if aggregate == "vote":
        # max() returns the first key with the top count, i.e. the lowest index on ties.
        best_index = max(predicted_index_counts, key=predicted_index_counts.get)
    else:
        best_index = int(probabilities.argmax())
    return integer_to_label[best_index]


# Evaluate slides round-robin across the subtypes so the running accuracy stays
# roughly class-balanced. Evaluation stops as soon as any subtype runs out of
# slides, so each subtype contributes about as many slides as the smallest one has.
sorted_val = val.sort_values('label')

# One iterator over the row index of each label group.
label_indices = {label: iter(rows.index) for label, rows in sorted_val.groupby('label')}

total = 0
total_correct = 0
done = False
while not done:
    for label in integer_to_label.values():
        # Only the iterator advance may signal exhaustion. Keeping process_sub_images
        # outside the try means a StopIteration raised while scoring a slide is not
        # mistaken for "this label has no more slides".
        try:
            index = next(label_indices[label])
        except StopIteration:
            done = True
            break

        row = sorted_val.loc[index]
        predicted_label = process_sub_images(row.tile_path)
        is_correct = predicted_label == row.label
        total_correct += is_correct
        total += 1
        print(f"{total} Image ID: {row['image_id']} True Label: {row.label} Correct? {is_correct} Accuracy: {total_correct / total}")

if total:
    print(f"Final accuracy: {total_correct}/{total} = {total_correct / total:.4f}")

combing through this many files: 20675
{0: 0, 1: 0, 2: 0, 3: 13, 4: 87}
[ 7.89310904 10.73271433 14.08968496 15.03374253 52.25074974]
[0.   0.   0.   0.13 0.87] 0.386386713521114
[0.07893109 0.10732714 0.14089685 0.15033742 0.52250749] 1.3401185476100776
1 Image ID: 65533 True Label: HGSC Correct? False Accuracy: 0.0
combing through this many files: 37500
{0: 0, 1: 0, 2: 0, 3: 41, 4: 59}
[10.74014418 10.7478469  16.03517925 29.29797848 33.17885116]
[0.   0.   0.   0.41 0.59] 0.6768585537461997
[0.10740144 0.10747847 0.16035179 0.29297978 0.33178851] 1.4985899012034498
2 Image ID: 12442 True Label: CC Correct? False Accuracy: 0.0
combing through this many files: 29739
{0: 1, 1: 0, 2: 0, 3: 8, 4: 91}
[ 1.71835934  3.78092766  7.05353341  6.27812879 81.16905063]
[0.01 0.   0.   0.08 0.91] 0.3339327170840958
[0.01718359 0.03780928 0.07053533 0.06278129 0.81169051] 0.7238308443139467
3 Image ID: 60936 True Label: EC Correct? False Accuracy: 0.0
combing through this many files: 17893
{0: 2, 