In [1]:
import os
import pandas as pd
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from collections import Counter
from sklearn.preprocessing import LabelEncoder
import torch

In [9]:
# --- CONFIG ---
KVASIR_IMAGE_DIR = 'Processed_Kvasir_labeled_images'
KVASIR_META_PATH = 'metadata.csv'

SEEAI_IMAGE_DIR = 'SEE_AI_project_all_images/SEE_AI_project_all_images'
SEEAI_ANNOTATION_PATH = 'all_annotation.csv'

In [10]:
# --- Load Kvasir Data ---
kvasir_df = pd.read_csv(KVASIR_META_PATH, sep=';')
kvasir_df = kvasir_df.dropna(subset=['finding_class'])

In [11]:
kvasir_df

Unnamed: 0,filename,video_id,frame_number,finding_category,finding_class,x1,y1,x2,y2,x3,y3,x4,y4
0,0728084c8da942d9_22803.jpg,0728084c8da942d9,22803,Luminal,Normal clean mucosa,,,,,,,,
1,0728084c8da942d9_22804.jpg,0728084c8da942d9,22804,Luminal,Normal clean mucosa,,,,,,,,
2,0728084c8da942d9_22805.jpg,0728084c8da942d9,22805,Luminal,Normal clean mucosa,,,,,,,,
3,0728084c8da942d9_22806.jpg,0728084c8da942d9,22806,Luminal,Normal clean mucosa,,,,,,,,
4,0728084c8da942d9_22807.jpg,0728084c8da942d9,22807,Luminal,Normal clean mucosa,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...
47243,04a78ef00c5245e0_898.jpg,04a78ef00c5245e0,898,Luminal,Angiectasia,171.0,172.0,237.0,172.0,237.0,238.0,171.0,238.0
47244,04a78ef00c5245e0_899.jpg,04a78ef00c5245e0,899,Luminal,Angiectasia,169.0,205.0,234.0,205.0,234.0,273.0,169.0,273.0
47245,04a78ef00c5245e0_900.jpg,04a78ef00c5245e0,900,Luminal,Angiectasia,174.0,240.0,242.0,240.0,242.0,305.0,174.0,305.0
47246,04a78ef00c5245e0_901.jpg,04a78ef00c5245e0,901,Luminal,Angiectasia,188.0,265.0,257.0,265.0,257.0,330.0,188.0,330.0


In [17]:
# --- Load See-AI Data ---
seeai_df = pd.read_csv(SEEAI_ANNOTATION_PATH)

In [18]:
seeai_df

Unnamed: 0,annotation_class_name,image_number,angiodysplasia,erosion,stenosis,lymphangiectasia,lymph follicle,SMT,polyp-like,bleeding,diverticulum,erythema,foreign body,vein,annotation_number
0,annotation_class_number,,0,1,2,3,4,5,6,7,8,9,10,11,
1,,1.0,0,0,0,0,0,0,0,0,0,0,0,0,0.0
2,,2.0,0,0,0,0,0,0,0,0,0,0,0,0,0.0
3,,3.0,0,0,0,0,0,0,0,0,0,0,0,0,0.0
4,,4.0,0,0,0,0,0,0,0,0,0,0,0,0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
18478,,18478.0,0,0,0,0,0,0,0,0,0,0,0,0,0.0
18479,,18479.0,0,0,0,0,0,0,0,0,0,0,0,0,0.0
18480,,18480.0,0,0,0,0,1,0,0,0,0,0,0,0,1.0
18481,,18481.0,0,0,0,0,0,0,0,0,0,0,0,0,0.0


In [21]:
seeai_df = seeai_df.iloc[:,1:].drop([0,18482])

In [22]:
seeai_df

Unnamed: 0,image_number,angiodysplasia,erosion,stenosis,lymphangiectasia,lymph follicle,SMT,polyp-like,bleeding,diverticulum,erythema,foreign body,vein,annotation_number
1,1.0,0,0,0,0,0,0,0,0,0,0,0,0,0.0
2,2.0,0,0,0,0,0,0,0,0,0,0,0,0,0.0
3,3.0,0,0,0,0,0,0,0,0,0,0,0,0,0.0
4,4.0,0,0,0,0,0,0,0,0,0,0,0,0,0.0
5,5.0,0,0,0,0,0,0,0,0,0,0,0,0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
18477,18477.0,0,0,0,0,1,0,0,0,0,0,0,0,1.0
18478,18478.0,0,0,0,0,0,0,0,0,0,0,0,0,0.0
18479,18479.0,0,0,0,0,0,0,0,0,0,0,0,0,0.0
18480,18480.0,0,0,0,0,1,0,0,0,0,0,0,0,1.0


In [63]:
seeai_df['image_number'] = seeai_df['image_number'].astype(int).apply(lambda x : f"{x:05}")

In [64]:
seeai_df

Unnamed: 0,image_number,angiodysplasia,erosion,stenosis,lymphangiectasia,lymph follicle,SMT,polyp-like,bleeding,diverticulum,erythema,foreign body,vein,annotation_number
1,00001,0,0,0,0,0,0,0,0,0,0,0,0,0.0
2,00002,0,0,0,0,0,0,0,0,0,0,0,0,0.0
3,00003,0,0,0,0,0,0,0,0,0,0,0,0,0.0
4,00004,0,0,0,0,0,0,0,0,0,0,0,0,0.0
5,00005,0,0,0,0,0,0,0,0,0,0,0,0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
18477,18477,0,0,0,0,1,0,0,0,0,0,0,0,1.0
18478,18478,0,0,0,0,0,0,0,0,0,0,0,0,0.0
18479,18479,0,0,0,0,0,0,0,0,0,0,0,0,0.0
18480,18480,0,0,0,0,1,0,0,0,0,0,0,0,1.0


In [23]:
kvasir_df.finding_class.value_counts()

finding_class
Normal clean mucosa     34338
Ileocecal valve          4189
Reduced Mucosal View     2906
Pylorus                  1538
Angiectasia               866
Ulcer                     854
Foreign Body              776
Lymphangiectasia          592
Erosion                   507
Blood - fresh             446
Erythema                  159
Polyp                      55
Blood - hematin            12
Ampulla of Vater           10
Name: count, dtype: int64

In [24]:
seeai_df.columns

Index(['image_number', 'angiodysplasia', 'erosion', 'stenosis',
       'lymphangiectasia', 'lymph follicle', 'SMT', 'polyp-like', 'bleeding',
       'diverticulum', 'erythema', 'foreign body', 'vein',
       'annotation_number'],
      dtype='object')

In [26]:
# --- Standardize Labels ---
label_mapping = {
    'Blood - fresh': 'bleeding',
    'Blood - hematin': 'bleeding',
    'Angiectasia': 'angiectasia',
    'Erosion': 'erosion',
    'Foreign Body': 'foreign body',
    'Lymphangiectasia': 'lymphangiectasia',
    'Polyp': 'polyp',
    'polyp-like': 'polyp',
    'Erythema': 'erythema',
    'Ulcer': 'ulcer',
    'Lymph follicle': 'lymph follicle',
    'Ampulla of Vater': 'ampulla of vater',
    'Pylorus': 'pylorus',
    'Ileocecal valve': 'ileocecal valve',
    'Reduced Mucosal View': 'reduced mucosal view',
    'Normal clean mucosa': 'normal clean mucosa'
}

In [27]:
def load_kvasir_images(kvasir_base_dir):
    image_paths, labels = [], []

    for class_folder in os.listdir(kvasir_base_dir):
        class_path = os.path.join(kvasir_base_dir, class_folder)
        if not os.path.isdir(class_path):
            continue
        for img in os.listdir(class_path):
            if img.endswith(('.jpg', '.png')):
                img_path = os.path.join(class_path, img)
                label = label_mapping.get(class_folder.lower(), class_folder.lower())
                image_paths.append(img_path)
                labels.append(label)
    return image_paths, labels

In [88]:
def load_seeai_images(seeai_df, seeai_images_dir):
    image_paths, labels = [], []

    class_columns = seeai_df.columns.drop(['image_number', 'annotation_number'])

    for _, row in seeai_df.iterrows():
        img_filename = f"image{int(row['image_number']):05}.jpg"
        img_path = os.path.join(seeai_images_dir, img_filename)
        #print(img_path)
        if not os.path.exists(img_path):
            continue

        row_classes = row[class_columns]
        if row_classes.sum() == 0:
            # No lesions detected, assign 'normal mucosal view'
            labels.append('normal clean mucosa')
        else:
            # Pick class with highest count
            max_class = row_classes.idxmax()
            label = label_mapping.get(max_class.lower(), max_class.lower())
            labels.append(label)

        image_paths.append(img_path)

    return image_paths, labels

In [32]:
# Load from both datasets
kvasir_paths, kvasir_labels = load_kvasir_images(KVASIR_IMAGE_DIR)

In [54]:
kvasir_paths[2912:2922]

['Processed_Kvasir_labeled_images/Reduced mucosal view/dc221ccc65d34010_49367.jpg',
 'Processed_Kvasir_labeled_images/Reduced mucosal view/dc221ccc65d34010_48774.jpg',
 'Processed_Kvasir_labeled_images/Reduced mucosal view/dc221ccc65d34010_48988.jpg',
 'Processed_Kvasir_labeled_images/Reduced mucosal view/8885668afb844852_25876.jpg',
 'Processed_Kvasir_labeled_images/Reduced mucosal view/5bb1d3cc7dc64cec_21501.jpg',
 'Processed_Kvasir_labeled_images/Reduced mucosal view/dc221ccc65d34010_49096.jpg',
 'Processed_Kvasir_labeled_images/Erythema/5e59c7fdb16c4228_26599.jpg',
 'Processed_Kvasir_labeled_images/Erythema/64440803f87b4843_28161.jpg',
 'Processed_Kvasir_labeled_images/Erythema/5e59c7fdb16c4228_26132.jpg',
 'Processed_Kvasir_labeled_images/Erythema/5e59c7fdb16c4228_26145.jpg']

In [55]:
kvasir_labels[2912:2922]

['reduced mucosal view',
 'reduced mucosal view',
 'reduced mucosal view',
 'reduced mucosal view',
 'reduced mucosal view',
 'reduced mucosal view',
 'erythema',
 'erythema',
 'erythema',
 'erythema']

In [56]:
set(kvasir_labels)

{'ampulla of vater',
 'angiectasia',
 'blood - fresh',
 'blood - hematin',
 'erosion',
 'erythema',
 'foreign body',
 'ileocecal valve',
 'lymphangiectasia',
 'normal clean mucosa',
 'polyp',
 'pylorus',
 'reduced mucosal view',
 'ulcer'}

In [57]:
kvasir_labels = ['bleeding' if label in ['blood - fresh', 'blood - hematin'] else label for label in kvasir_labels]

In [58]:
set(kvasir_labels)

{'ampulla of vater',
 'angiectasia',
 'bleeding',
 'erosion',
 'erythema',
 'foreign body',
 'ileocecal valve',
 'lymphangiectasia',
 'normal clean mucosa',
 'polyp',
 'pylorus',
 'reduced mucosal view',
 'ulcer'}

In [60]:
len(set(kvasir_labels))

13

In [89]:
seeai_paths, seeai_labels = load_seeai_images(seeai_df, SEEAI_IMAGE_DIR)

In [90]:
seeai_paths[15:22]

['SEE_AI_project_all_images/SEE_AI_project_all_images/image00016.jpg',
 'SEE_AI_project_all_images/SEE_AI_project_all_images/image00017.jpg',
 'SEE_AI_project_all_images/SEE_AI_project_all_images/image00018.jpg',
 'SEE_AI_project_all_images/SEE_AI_project_all_images/image00019.jpg',
 'SEE_AI_project_all_images/SEE_AI_project_all_images/image00020.jpg',
 'SEE_AI_project_all_images/SEE_AI_project_all_images/image00021.jpg',
 'SEE_AI_project_all_images/SEE_AI_project_all_images/image00022.jpg']

In [91]:
seeai_labels[15:22]

['normal clean mucosa',
 'normal clean mucosa',
 'polyp',
 'polyp',
 'polyp',
 'polyp',
 'polyp']

In [92]:
seeai_labels[18000:18020]

['lymph follicle',
 'lymph follicle',
 'lymph follicle',
 'lymph follicle',
 'lymph follicle',
 'lymph follicle',
 'polyp',
 'erosion',
 'polyp',
 'normal clean mucosa',
 'angiodysplasia',
 'lymphangiectasia',
 'lymphangiectasia',
 'polyp',
 'polyp',
 'polyp',
 'polyp',
 'polyp',
 'polyp',
 'polyp']

In [93]:
set(seeai_labels)

{'angiodysplasia',
 'bleeding',
 'diverticulum',
 'erosion',
 'erythema',
 'foreign body',
 'lymph follicle',
 'lymphangiectasia',
 'normal clean mucosa',
 'polyp',
 'smt',
 'stenosis',
 'vein'}

In [94]:
len(set(seeai_labels))

13

In [95]:
# Combine
all_image_paths = kvasir_paths + seeai_paths
all_labels = kvasir_labels + seeai_labels

In [96]:
len(all_image_paths)

65719

In [97]:
len(all_labels)

65719

In [98]:
set(all_labels)

{'ampulla of vater',
 'angiectasia',
 'angiodysplasia',
 'bleeding',
 'diverticulum',
 'erosion',
 'erythema',
 'foreign body',
 'ileocecal valve',
 'lymph follicle',
 'lymphangiectasia',
 'normal clean mucosa',
 'polyp',
 'pylorus',
 'reduced mucosal view',
 'smt',
 'stenosis',
 'ulcer',
 'vein'}

In [99]:
len(set(all_labels))

19

In [100]:
# Encode string labels to integers
le = LabelEncoder()
encoded_labels = le.fit_transform(all_labels)

In [101]:
import joblib
joblib.dump(le, 'cv_label_encoder.pkl')

['cv_label_encoder.pkl']

#### Handle Class Imbalance with Weighted Sampler

In [102]:
label_counts = Counter(encoded_labels)
total_count = sum(label_counts.values())

In [103]:
label_counts

Counter({11: 40499,
         5: 4691,
         8: 4189,
         14: 2906,
         12: 2116,
         9: 1715,
         13: 1529,
         7: 1229,
         10: 1143,
         3: 1108,
         1: 866,
         17: 854,
         2: 804,
         6: 706,
         18: 511,
         15: 453,
         16: 359,
         4: 31,
         0: 10})

In [104]:
total_count

65719

In [105]:
weights = [1.0 / label_counts[label] for label in encoded_labels]
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

In [109]:
weights[47253:47260]

[2.4691967702906246e-05,
 2.4691967702906246e-05,
 0.0004725897920604915,
 0.0004725897920604915,
 0.0004725897920604915,
 0.0004725897920604915,
 0.0004725897920604915]

In [110]:
class MedicalImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        image = self.transform(image)
        label = self.labels[idx]
        return image, label

In [111]:
dataset = MedicalImageDataset(all_image_paths, encoded_labels)
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler, num_workers=4)

#### Models

In [112]:
import torch.nn as nn
from torchvision import models
from torchvision.models import resnet18, resnet50, densenet121
from torchvision.models import efficientnet_b0

In [113]:
def get_embedding_model(model_name='resnet18'):
    if model_name == 'resnet18':
        base_model = resnet18(pretrained=True)
        feature_dim = base_model.fc.in_features
        base_model.fc = nn.Identity()  # Remove classifier

    elif model_name == 'resnet50':
        base_model = resnet50(pretrained=True)
        feature_dim = base_model.fc.in_features
        base_model.fc = nn.Identity()

    elif model_name == 'efficientnet_b0':
        base_model = efficientnet_b0(pretrained=True)
        feature_dim = base_model.classifier[1].in_features
        base_model.classifier = nn.Identity()

    elif model_name == 'densenet121':
        base_model = densenet121(pretrained=True)
        feature_dim = base_model.classifier.in_features
        base_model.classifier = nn.Identity()

    else:
        raise ValueError(f"Unknown model: {model_name}")

    return base_model, feature_dim

In [114]:
model_r18, feature_dim_r18 = get_embedding_model('resnet18')
# Set the model to evaluation mode (for inference)
model_r18.eval()

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|███████████████████████████████████████| 44.7M/44.7M [00:00<00:00, 205MB/s]


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [116]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_r18 = model_r18.to(device)

In [118]:
from tqdm import tqdm
import numpy as np
import os

In [120]:
# Lists to store embeddings and labels
embeddings_list_r18 = []
labels_list_r18 = []

In [121]:
for images, labels in tqdm(dataloader, desc="Extracting Embeddings"):
    images = images.to(device)  
    labels = labels.to(device)  

    # Forward pass through the model (without classifier)
    with torch.no_grad():
        embeddings = model_r18(images)  # Get embeddings (no classifier)
    
    # Store embeddings and labels
    embeddings_list_r18.append(embeddings.cpu().detach().numpy())  # Detach from the computation graph
    labels_list_r18.append(labels.cpu().numpy())  # Store the labels

Extracting Embeddings: 100%|████████████████| 2054/2054 [06:33<00:00,  5.22it/s]


In [122]:
# Store for future use
embeddings_array_r18 = np.concatenate(embeddings_list_r18, axis=0)
labels_array_r18 = np.concatenate(labels_list_r18, axis=0)

In [123]:
np.save('embeddings_resnet18.npy', embeddings_array_r18)
np.save('labels_resnet18.npy', labels_array_r18)
torch.save(model_r18.state_dict(), 'feature_extractor_resnet18.pth')

In [138]:
model_r50, feature_dim_r50 = get_embedding_model('resnet50')
model_r50.eval()
model_r50 = model_r50.to(device)
embeddings_list_r50 = []
labels_list_r50 = []

for images, labels in tqdm(dataloader, desc="Extracting Embeddings"):
    images = images.to(device)  
    labels = labels.to(device)  

    with torch.no_grad():
        embeddings = model_r50(images)  
    
    # Store embeddings and labels
    embeddings_list_r50.append(embeddings.cpu().detach().numpy())
    labels_list_r50.append(labels.cpu().numpy())

embeddings_array_r50 = np.concatenate(embeddings_list_r50, axis=0)
labels_array_r50 = np.concatenate(labels_list_r50, axis=0)

np.save('embeddings_resnet50.npy', embeddings_array_r50)
np.save('labels_resnet50.npy', labels_array_r50)
torch.save(model_r50.state_dict(), 'feature_extractor_resnet50.pth')

Extracting Embeddings: 100%|████████████████| 2054/2054 [06:34<00:00,  5.21it/s]


In [140]:
model_eff, feature_dim_eff = get_embedding_model('efficientnet_b0')
model_eff.eval()
model_eff = model_eff.to(device)
embeddings_list_eff = []
labels_list_eff = []

for images, labels in tqdm(dataloader, desc="Extracting Embeddings"):
    images = images.to(device)  
    labels = labels.to(device)  

    with torch.no_grad():
        embeddings = model_eff(images)  
    
    # Store embeddings and labels
    embeddings_list_eff.append(embeddings.cpu().detach().numpy())
    labels_list_eff.append(labels.cpu().numpy())

embeddings_array_eff = np.concatenate(embeddings_list_eff, axis=0)
labels_array_eff = np.concatenate(labels_list_eff, axis=0)

np.save('embeddings_efficientnetb0.npy', embeddings_array_eff)
np.save('labels_efficientnetb0.npy', labels_array_eff)
torch.save(model_eff.state_dict(), 'feature_extractor_efficientnetb0.pth')

Extracting Embeddings: 100%|████████████████| 2054/2054 [06:37<00:00,  5.16it/s]


In [141]:
model_den, feature_dim_den = get_embedding_model('densenet121')
model_den.eval()
model_den = model_den.to(device)
embeddings_list_den = []
labels_list_den = []

for images, labels in tqdm(dataloader, desc="Extracting Embeddings"):
    images = images.to(device)  
    labels = labels.to(device)  

    with torch.no_grad():
        embeddings = model_den(images)  
    
    # Store embeddings and labels
    embeddings_list_den.append(embeddings.cpu().detach().numpy())
    labels_list_den.append(labels.cpu().numpy())

embeddings_array_den = np.concatenate(embeddings_list_den, axis=0)
labels_array_den = np.concatenate(labels_list_den, axis=0)

np.save('embeddings_densenet121.npy', embeddings_array_den)
np.save('labels_densenet121.npy', labels_array_den)
torch.save(model_den.state_dict(), 'feature_extractor_densenet121.pth')

Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth
100%|███████████████████████████████████████| 30.8M/30.8M [00:00<00:00, 141MB/s]
Extracting Embeddings: 100%|████████████████| 2054/2054 [06:50<00:00,  5.01it/s]
