In [27]:
import pandas as pd
import os

def get_image_path(image_id: int, tiles_dir: str = "tiles") -> str:
    """Return the directory that holds the tiles for one slide.

    Args:
        image_id: Slide identifier (the ``image_id`` column of the CSV).
        tiles_dir: Root directory containing one sub-directory per slide.
            Defaults to ``"tiles"``, relative to the current working directory.

    Returns:
        ``os.path.join(tiles_dir, str(image_id))``.
    """
    return os.path.join(tiles_dir, str(image_id))

# NOTE(review): despite the name `val`, this reads "train-yes-tma.csv" (every
# row shown below has is_tma=True). Later cells depend on the name, so it is kept.
val = pd.read_csv("train-yes-tma.csv")

# Directory whose .png tiles process_sub_images samples for each slide.
val['tile_path'] = val['image_id'].apply(get_image_path)

val.head()

Unnamed: 0,image_id,label,image_width,image_height,is_tma,tile_path
0,31594,EC,3388,3388,True,tiles/31594
1,41586,CC,2964,2964,True,tiles/41586
2,91,HGSC,3388,3388,True,tiles/91
3,36583,LGSC,3388,3388,True,tiles/36583
4,35565,MC,2964,2964,True,tiles/35565


In [28]:
import os

def count_files(directory):
    """Count every file beneath ``directory``, descending into subdirectories.

    Returns the count as an ``int``. If ``directory`` is missing or is not a
    directory, an explanatory message string is returned instead (no
    exception is raised).
    """
    if os.path.exists(directory):
        if os.path.isdir(directory):
            return sum(len(filenames) for _root, _subdirs, filenames in os.walk(directory))
        return "The specified path is not a directory"
    return "The specified directory does not exist"

In [45]:
from PIL import Image
import torch
import torch.nn as nn
from transformers import ViTImageProcessor, ViTForImageClassification

device = "cpu"
model_name = "google/vit-base-patch16-224"
# Fine-tuned weights saved as a plain state_dict (loaded directly via
# load_state_dict below).
checkpoint_path = 'vit-finetune-non-tma-models-pt-4/epoch_0_batch_10000.pth'
num_classes = 5  # HGSC, CC, EC, LGSC, MC
print(f"Using device {device} and model {model_name}")

processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
# Swap the pretrained head for a `num_classes`-way head *before* loading the
# checkpoint so the classifier weight shapes match.
model.classifier = nn.Linear(model.classifier.in_features, num_classes)
# Inference-only notebook: put the model in eval mode as soon as it is built.
model = model.to(device).eval()
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs, and avoids executing arbitrary pickled code.
state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

Using device cpu and model google/vit-base-patch16-224


<All keys matched successfully>

In [43]:
# Model output index -> subtype label. Presumably this matches the label
# encoding used during fine-tuning -- TODO confirm against the training code.
integer_to_label = {
    0: 'HGSC',
    1: 'CC',
    2: 'EC',
    3: 'LGSC',
    4: 'MC',
}

# Inverse mapping, derived rather than hand-copied so the two can never drift.
label_to_integer = {label: index for index, label in integer_to_label.items()}

In [46]:
import os
from PIL import Image, ImageFile
import torch
import pandas as pd
import torchvision.transforms as transforms
import random
import math
import numpy as np
from scipy.stats import entropy

# Single-step pipeline: convert a PIL image to a tensor with ToTensor.
# NOTE(review): no visible code uses `transform` -- process_sub_images
# preprocesses with `processor` instead. Remove if no later cell needs it.
transform = transforms.Compose([transforms.ToTensor()])

def process_sub_images(path, max_samples=100, aggregate="probability"):
    """Predict one slide-level subtype label from a random sample of its tiles.

    Up to ``max_samples`` ``.png`` files located directly inside ``path`` are
    sampled with the global ``random`` module (call ``random.seed`` first for
    reproducible runs). Each tile is classified independently with the global
    ``processor``/``model`` on ``device``. The per-tile results are then
    combined into a single label. Diagnostic counts, summed probabilities and
    their entropies are printed along the way.

    Args:
        path: Directory containing the tile images.
        max_samples: Maximum number of tiles to classify (default 100).
        aggregate: How per-tile results are combined.
            ``"probability"`` (default) picks the class with the largest summed
            softmax probability. ``"vote"`` picks the class predicted for the
            most tiles; ties go to the lowest class index.

    Returns:
        A label from ``integer_to_label``, or ``None`` when ``path`` contains
        no ``.png`` files (instead of silently reporting class 0).

    Raises:
        ValueError: If ``aggregate`` is neither ``"probability"`` nor ``"vote"``.
        FileNotFoundError: If ``path`` does not exist (raised by ``os.listdir``).
    """
    if aggregate not in ("probability", "vote"):
        raise ValueError(f"aggregate must be 'probability' or 'vote', got {aggregate!r}")

    num_classes = len(integer_to_label)
    predicted_index_counts = {index: 0 for index in range(num_classes)}
    probabilities = np.zeros(num_classes)

    # Only top-level .png files are candidates; other files and nested
    # directories are ignored, so report the count that is actually sampled.
    all_files = [f for f in os.listdir(path) if f.lower().endswith('.png')]
    sample_size = min(max_samples, len(all_files))
    print(f'combing through {sample_size} of {len(all_files)} png files in {path}')
    if sample_size == 0:
        # Nothing to aggregate: argmax over all-zero scores would pick class 0.
        return None
    sampled_files = random.sample(all_files, sample_size)

    # Inference only: skip building an autograd graph for every tile.
    with torch.no_grad():
        for image_name in sampled_files:
            # Context manager closes the file handle; convert() makes a
            # 3-channel copy so RGBA/greyscale/palette PNGs are handled too.
            with Image.open(os.path.join(path, image_name)) as tile:
                sub_image = tile.convert("RGB")

            inputs = processor(images=sub_image, return_tensors="pt").to(device)
            logits = model(**inputs).logits

            predicted_index_counts[logits.argmax(dim=1).item()] += 1
            probabilities += logits.softmax(dim=1).cpu().numpy()[0]

    index_probabilities = np.array(list(predicted_index_counts.values()))
    print(predicted_index_counts)
    print(probabilities)
    print(index_probabilities / index_probabilities.sum(), entropy(index_probabilities + 1e-8))
    print(probabilities / probabilities.sum(), entropy(probabilities))

    if aggregate == "vote":
        # Majority vote over per-tile argmax predictions.
        best_index = max(predicted_index_counts, key=predicted_index_counts.get)
    else:
        # Class with the highest summed softmax probability across tiles.
        best_index = int(probabilities.argmax())
    return integer_to_label[best_index]


# Evaluate an equal number of slides per class, interleaving the classes
# (HGSC, CC, EC, LGSC, MC, HGSC, ...) so the running accuracy printed below
# stays class-balanced at every step.

# Sort the dataframe by 'label' to ensure the order of categories.
sorted_val = val.sort_values('label')

# Interleaving order follows the model's class indices.
eval_labels = list(integer_to_label.values())

# Row indices per label, plus an iterator per label consumed one row per round.
grouped_indices = {label: list(rows.index) for label, rows in sorted_val.groupby('label')}
label_indices = {label: iter(indices) for label, indices in grouped_indices.items()}

missing_labels = [label for label in eval_labels if label not in grouped_indices]
if missing_labels:
    raise ValueError(f"No rows found for labels: {missing_labels}")

# The smallest class bounds the number of complete rounds, so every class
# contributes exactly the same number of slides (no partial final round).
n_rounds = min(len(grouped_indices[label]) for label in eval_labels)

total = 0
total_correct = 0
model.eval()
with torch.no_grad():  # inference only; no autograd bookkeeping needed
    for _ in range(n_rounds):
        for label in eval_labels:
            # Cannot raise StopIteration: n_rounds <= len of every label's rows.
            row = sorted_val.loc[next(label_indices[label])]
            predicted_label = process_sub_images(row.tile_path)
            is_correct = predicted_label == row.label
            total_correct += is_correct
            total += 1
            print(f"{total} Image ID: {row['image_id']} True Label: {row.label} Correct? {is_correct} Accuracy: {total_correct / total}")

if total:
    print(f"Final accuracy: {total_correct / total:.4f} over {total} slides ({n_rounds} per class)")
else:
    print("No slides evaluated.")

combing through this many files: 265
{0: 4, 1: 22, 2: 7, 3: 26, 4: 41}
[17.64284172 17.51408403 20.29660708 19.55779859 24.98866824]
[0.04 0.22 0.07 0.26 0.41] 1.3638057144174334
[0.17642842 0.17514084 0.20296607 0.19557799 0.24988668] 1.6005458111767035
1 Image ID: 91 True Label: HGSC Correct? False Accuracy: 0.0
combing through this many files: 195
{0: 2, 1: 31, 2: 2, 3: 14, 4: 51}
[12.2030745  24.32177881 13.81616968 22.9475533  26.71142315]
[0.02 0.31 0.02 0.14 0.51] 1.13820916713412
[0.12203075 0.24321779 0.1381617  0.22947553 0.26711423] 1.56440839962102
2 Image ID: 41586 True Label: CC Correct? False Accuracy: 0.0
combing through this many files: 115
{0: 8, 1: 0, 2: 56, 3: 33, 4: 3}
[19.31303218  3.52988147 35.81410391 27.46700898 13.8759738 ]
[0.08 0.   0.56 0.33 0.03] 0.9978120545737951
[0.19313032 0.03529881 0.35814104 0.27467009 0.13875974] 1.4323433999634128
3 Image ID: 31594 True Label: EC Correct? True Accuracy: 0.3333333333333333
combing through this many files: 154
{0: 