In [1]:
import torch
import torch.nn as nn

# Hyper-parameters for the Vision Transformer defined in this notebook.
config = dict(
    image_size=224,   # with patch_size, fixes the number of positional embeddings
    patch_size=32,    # kernel size and stride of the patch-embedding convolution
    num_classes=5,    # width of the classification head
    dim=768,          # token embedding width
    depth=12,         # number of stacked transformer blocks
    heads=12,         # attention heads per block
    mlp_dim=3072,     # NOTE(review): not read by VisionTransformer in this notebook
)

class PatchEmbedding(nn.Module):
    """Project an image batch into a sequence of patch embeddings.

    A single convolution whose kernel size and stride both equal ``patch_size`` maps every
    non-overlapping patch to one ``emb_size``-dimensional vector.

    Args:
        in_channels: Number of channels in the input images.
        patch_size: Side length of each square patch.
        emb_size: Width of the embedding produced for each patch.
    """

    def __init__(self, in_channels=3, patch_size=16, emb_size=768):
        super().__init__()
        self.patch_embed = nn.Conv2d(
            in_channels,
            emb_size,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return the patch tokens with shape ``[B, num_patches, emb_size]``."""
        feature_map = self.patch_embed(x)   # [B, emb_size, H', W'] (one cell per patch)
        tokens = feature_map.flatten(2)     # [B, emb_size, H' * W']
        return tokens.transpose(1, 2)       # [B, H' * W', emb_size]

class Attention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Args:
        dim: Embedding width of the input tokens. Must be divisible by ``heads``.
        heads: Number of attention heads.

    Raises:
        ValueError: If ``heads`` is not positive or ``dim`` is not divisible by ``heads``.
    """

    def __init__(self, dim, heads=12):
        super().__init__()
        # Fail early: an indivisible dim would otherwise only surface in forward() as an
        # opaque reshape error.
        if heads <= 0 or dim % heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by a positive number of heads ({heads})")

        self.heads = heads
        # NOTE(review): the logits are scaled by the full embedding width rather than the
        # per-head width (dim // heads). Left unchanged on purpose because altering it would
        # change the outputs of any checkpoint trained with this module - confirm intent.
        self.scale = dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``[B, N, C]`` and return the same shape."""
        B, N, C = x.shape
        head_dim = C // self.heads

        # One projection yields q, k and v together; split it into [3, B, heads, N, head_dim].
        qkv = self.qkv(x).reshape(B, N, 3, self.heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)

        # Scaled dot-product scores, normalised over the key axis.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        # Weighted sum of the values, then merge the heads back into the channel axis.
        merged = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(merged)

class MLP(nn.Module):
    """Feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_features: Width of the input.
        hidden_features: Width of the hidden layer.
        out_features: Width of the output.
        p: Dropout probability; the same dropout module is applied after both stages.
    """

    def __init__(self, in_features, hidden_features, out_features, p=0.1):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        """Run ``x`` through both linear stages, applying dropout after each."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each with a residual add.

    Args:
        dim: Token embedding width.
        heads: Number of attention heads.
        mlp_ratio: The MLP hidden width is ``int(dim * mlp_ratio)``.
        qkv_bias: Accepted but not used by this implementation.
        p: Dropout probability applied to each sub-layer's output before the residual add.
        attn_p: Accepted but not used by this implementation.
    """

    def __init__(self, dim, heads, mlp_ratio=4., qkv_bias=True, p=0., attn_p=0.):
        super().__init__()
        hidden_width = int(dim * mlp_ratio)

        # Attention sub-layer.
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(dim, heads=heads)
        self.drop_path = nn.Dropout(p)  # ordinary element-wise dropout, despite the name

        # Feed-forward sub-layer.
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = MLP(in_features=dim, hidden_features=hidden_width, out_features=dim)
        self.mlp_drop = nn.Dropout(p)

    def forward(self, x):
        """Apply both residual sub-layers to ``x`` and return a tensor of the same shape."""
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_out)

        mlp_out = self.mlp(self.norm2(x))
        return x + self.mlp_drop(mlp_out)

class VisionTransformer(nn.Module):
    """Vision Transformer classifier that predicts from the final class token.

    Args:
        config: Mapping that must provide ``'patch_size'``, ``'dim'``, ``'image_size'``,
            ``'heads'``, ``'depth'`` and ``'num_classes'``. Other keys (for example
            ``'mlp_dim'``) are not read here; every block uses its default MLP ratio.
    """

    def __init__(self, config):
        super().__init__()
        patch_size = config['patch_size']
        dim = config['dim']
        # One position per patch of the square grid, plus one for the class token.
        seq_len = (config['image_size'] // patch_size) ** 2 + 1

        self.patch_embed = PatchEmbedding(patch_size=patch_size, emb_size=dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, dim))
        self.pos_drop = nn.Dropout(p=0.1)

        encoder_layers = [Block(dim=dim, heads=config['heads']) for _ in range(config['depth'])]
        self.blocks = nn.Sequential(*encoder_layers)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.head = nn.Linear(dim, config['num_classes'])

    def forward(self, x):
        """Return class logits of shape ``[B, num_classes]`` for an image batch ``x``."""
        batch_size = x.shape[0]
        tokens = self.patch_embed(x)

        # Prepend the class token to every sequence, then add positional information.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens = self.pos_drop(tokens + self.pos_embed)

        encoded = self.norm(self.blocks(tokens))

        # Only the class token (position 0) feeds the classification head.
        return self.head(encoded[:, 0])

In [16]:
import pandas as pd
import os

def get_image_path(image_id:int):
    """Return the relative path of the tile directory named after ``image_id``."""
    tile_dir_name = str(image_id)
    return os.path.join("tiles", tile_dir_name)

# Load the per-image metadata table.
val = pd.read_csv("train-no-tma.csv")

# Attach, for every image, the directory that holds its tiles.
val['tile_path'] = val['image_id'].map(get_image_path)

val.head()

Unnamed: 0,image_id,label,image_width,image_height,is_tma,tile_path
0,38366,LGSC,31951,21718,False,tiles/38366
1,63298,HGSC,26067,20341,False,tiles/63298
2,54928,CC,36166,31487,False,tiles/54928
3,18813,CC,54671,32443,False,tiles/18813
4,63429,EC,67783,29066,False,tiles/63429


In [17]:
import os

def count_files(directory):
    """Count the files under ``directory``, including those in nested sub-directories.

    Returns:
        The number of files as an ``int``. If ``directory`` does not exist or is not a
        directory, a human-readable message string is returned instead of raising.
    """
    if os.path.isdir(directory):
        # os.walk yields one (root, dirs, files) triple per directory in the tree.
        return sum(len(filenames) for _root, _subdirs, filenames in os.walk(directory))

    if os.path.exists(directory):
        return "The specified path is not a directory"

    return "The specified directory does not exist"

In [37]:
import os
from PIL import Image, ImageFile
import torch
from transformers import CLIPProcessor, CLIPModel
import pandas as pd
import torchvision.transforms as transforms
import random
import math
import numpy as np
from scipy.stats import entropy


# Define the paths
model_weights_path = "vit-non-tma-models-pt-2/model_step_29000.pt"  # Path to the model weights

# Set up the device and the model.
# FORCE_CPU pins inference to the CPU regardless of CUDA availability; set it to False to
# run on a GPU when one is available.
FORCE_CPU = True
device = "cuda" if (torch.cuda.is_available() and not FORCE_CPU) else "cpu"

model = VisionTransformer(config)
model.to(device)

# Load the weights.
# NOTE(security): torch.load unpickles the file, so only load checkpoints from a trusted source.
state_dict = torch.load(model_weights_path, map_location=device)

# strict=False tolerates key mismatches, but it also means a checkpoint whose keys do not line
# up would silently leave parameters at their initial values - so report anything unmatched.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(
        f"WARNING: {len(load_result.missing_keys)} model parameter(s) missing from the checkpoint, "
        f"e.g. {load_result.missing_keys[:10]}"
    )
if load_result.unexpected_keys:
    print(
        f"WARNING: {len(load_result.unexpected_keys)} checkpoint key(s) not used by the model, "
        f"e.g. {load_result.unexpected_keys[:10]}"
    )

# Switch to inference mode so the dropout layers are disabled during evaluation.
model.eval()

# Class index -> label name, in the order of the model's output logits.
integer_to_label = dict(enumerate(['HGSC', 'CC', 'EC', 'LGSC', 'MC']))

# Reverse lookup: label name -> class index.
label_to_integer = {name: idx for idx, name in integer_to_label.items()}

# Tile preprocessing: resize to 224x224, convert to a tensor, then normalise per channel.
TILE_SIZE = (224, 224)
# Channel statistics for the non-TMA images (computed elsewhere, not in this cell).
NO_TMA_MEAN = [0.8265, 0.7217, 0.8247]
NO_TMA_STD = [0.1133, 0.1265, 0.0960]

transform = transforms.Compose(
    [
        transforms.Resize(TILE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=NO_TMA_MEAN, std=NO_TMA_STD),
    ]
)

def process_sub_images(path, sample_size=100):
    """Classify one image by aggregating model predictions over a random sample of its tiles.

    Args:
        path: Directory containing the image's ``.png`` tiles.
        sample_size: Maximum number of tiles to sample; all tiles are used when fewer exist.

    Returns:
        The label whose softmax probability, summed over the sampled tiles, is highest.

    Raises:
        ValueError: If ``sample_size`` is below 1 or ``path`` holds no ``.png`` files
            (previously this case silently produced NaN statistics and the first label).
    """
    if sample_size < 1:
        raise ValueError(f"sample_size must be at least 1, got {sample_size}")

    print('combing through this many files:', count_files(path))

    # Get all .png files from the directory
    all_files = [f for f in os.listdir(path) if f.lower().endswith('.png')]
    if not all_files:
        raise ValueError(f"No .png tiles found in {path!r}")

    # Randomly select up to `sample_size` files
    sampled_files = random.sample(all_files, min(sample_size, len(all_files)))

    predicted_index_counts = dict.fromkeys(integer_to_label, 0)
    probabilities = np.zeros(len(integer_to_label))

    # Inference must run with dropout disabled and without autograd bookkeeping; the
    # model's previous train/eval mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for image_name in sampled_files:
                image_path = os.path.join(path, image_name)
                # Close the file promptly and force 3 channels (PNG tiles may be RGBA,
                # palette or greyscale, which the 3-channel normalisation cannot handle).
                with Image.open(image_path) as tile:
                    tile_tensor = transform(tile.convert('RGB')).unsqueeze(0).to(device)

                outputs = model(tile_tensor)
                probs = outputs.softmax(dim=1)
                predicted_index = outputs.argmax(dim=1).item()

                predicted_index_counts[predicted_index] += 1
                probabilities += probs.cpu().numpy()[0]
    finally:
        model.train(was_training)

    # Diagnostics: hard-vote distribution vs. summed soft probabilities, each with its entropy.
    index_probabilities = np.array(list(predicted_index_counts.values()))
    print(predicted_index_counts)
    print(probabilities)
    print(index_probabilities / index_probabilities.sum(), entropy(index_probabilities + 1e-8))
    print(probabilities / probabilities.sum(), entropy(probabilities))

    # Return label with highest probability
    return integer_to_label[int(probabilities.argmax())]

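# Alternative aggregation (disabled): majority vote over the per-tile predictions,
# instead of the summed-probability argmax returned by process_sub_images above.
#     # Find the index with the highest count
#     highest_index = max(predicted_index_counts, key=predicted_index_counts.get)
#
#     # Return the label associated with the highest index
#     return integer_to_label[highest_index]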


# Sort the dataframe by 'label' so rows of the same class are contiguous.
sorted_val = val.sort_values('label')

# One iterator over row indexes per label, consumed as the loop below visits each class.
label_indices = {label: iter(rows.index) for label, rows in sorted_val.groupby('label')}

total = 0
total_correct = 0
done = False

# Visit the classes round-robin so the running accuracy stays class-balanced, and stop as
# soon as any class runs out of images.
while not done:
    for label in label_to_integer:
        # Keep the try narrow: only exhausting this label's iterator should end the
        # evaluation, not a StopIteration escaping from the inference code.
        try:
            index = next(label_indices[label])
        except StopIteration:
            done = True
            break

        row = sorted_val.loc[index]
        predicted_label = process_sub_images(row.tile_path)
        is_correct = predicted_label == row.label
        total_correct += int(is_correct)
        total += 1
        print(f"{total} Image ID: {row['image_id']} True Label: {row.label} Correct? {is_correct} Accuracy: {total_correct / total}")


combing through this many files: 20675
{0: 15, 1: 20, 2: 28, 3: 4, 4: 33}
[19.74587366 20.86293369 24.56795904 12.36310356 22.46012924]
[0.15 0.2  0.28 0.04 0.33] 1.4574996687161268
[0.19745874 0.20862934 0.24567959 0.12363104 0.22460129] 1.5860237269354076
1 Image ID: 65533 True Label: HGSC Correct? False Accuracy: 0.0
combing through this many files: 37500
{0: 2, 1: 0, 2: 90, 3: 1, 4: 7}
[11.26259083 13.83724283 44.72092406 15.24236908 14.93687293]
[0.02 0.   0.9  0.01 0.07] 0.40526483197391594
[0.11262591 0.13837243 0.44720924 0.15242369 0.14936873] 1.4502185360090978
2 Image ID: 12442 True Label: CC Correct? False Accuracy: 0.0
combing through this many files: 29739
{0: 11, 1: 46, 2: 30, 3: 0, 4: 13}
[19.7698997  25.05053087 22.91280146 13.34640963 18.92035869]
[0.11 0.46 0.3  0.   0.13] 1.2264240350388134
[0.197699   0.25050531 0.22912801 0.1334641  0.18920359] 1.5886519462386441
3 Image ID: 60936 True Label: EC Correct? False Accuracy: 0.0
combing through this many files: 17893
{

KeyboardInterrupt: 