In [None]:
import numpy as np
import pandas as pd
import pandas.api.types
import jax
import jax.numpy as jnp
from jax import jit
from flax import linen as nn

import h5py
import itertools
from PIL import Image
from pathlib import Path
import time

import matplotlib.pyplot as plt
from sklearn.preprocessing import OrdinalEncoder
from sklearn.preprocessing import MinMaxScaler, StandardScaler

from scipy.stats import zscore
from numpy import nanmean, nanstd
np.seterr(invalid='ignore')

import copy
import os
import io

In [None]:
# Tabular metadata CSVs for the training and test images.
file_path_train_meta, file_path_test_meta = (
    '/kaggle/input/isic-2024-challenge/train-metadata.csv',
    '/kaggle/input/isic-2024-challenge/test-metadata.csv',
)

df_train, df_test = (pd.read_csv(csv_path) for csv_path in (file_path_train_meta, file_path_test_meta))

In [None]:
# Ground-truth labels for the training rows, in train-metadata.csv row order.
train_targets = df_train.loc[:, 'target'].to_numpy()

In [None]:
def filter_columns(df: pd.DataFrame, keep_columns: list) -> pd.DataFrame:
    """
    Select the columns of ``df`` named in ``keep_columns``, silently skipping names
    that ``df`` does not have.

    :param df: Input DataFrame
    :param keep_columns: Column names to keep; their order sets the output column order
    :return: New DataFrame containing only the requested columns that exist in ``df``
    """
    present = set(df.columns)
    wanted = [name for name in keep_columns if name in present]
    return df[wanted]

# Restrict the training metadata to the columns that are also available at test time.
test_cols = df_test.columns
print(test_cols)
filtered_df_train = filter_columns(keep_columns=test_cols, df=df_train)


In [None]:
# Metadata columns that are treated as categorical and ordinal-encoded.
cat_cols = (
    'patient_id age_approx sex anatom_site_general '
    'image_type tbp_tile_type tbp_lv_location '
    'tbp_lv_location_simple attribution copyright_license'
).split()

In [None]:
# Own an independent copy before writing encoded values back into it.
filtered_df_train = filtered_df_train.copy()

# Categories are learned from the training frame only. At transform time, unseen
# categories become -2 and missing values become -1.
category_encoder = OrdinalEncoder(
    categories='auto',
    dtype=int,
    handle_unknown='use_encoded_value',
    unknown_value=-2,
    encoded_missing_value=-1,
)

# Write each encoded column back with `.loc[:, col] = ...`.
# NOTE(review): in pandas >= 2 this assignment is done in place and keeps the column's
# existing dtype, so object columns stay object-typed (holding ints). They are then
# skipped by `select_dtypes(np.number)` during normalization. Confirm this is intended.
X_cat_train = category_encoder.fit_transform(filtered_df_train[cat_cols])
for cat_col, codes in zip(cat_cols, X_cat_train.T):
    filtered_df_train.loc[:, cat_col] = codes

# The test frame is encoded with the categories fitted on train. It is modified in
# place, so re-running this cell would re-encode already-encoded values.
X_cat_test = category_encoder.transform(df_test[cat_cols])
for cat_col, codes in zip(cat_cols, X_cat_test.T):
    df_test.loc[:, cat_col] = codes

print("Encoding complete.")

In [None]:
def normalize_dataframe_columns(df, method='minmax', exclude_columns=None, reference_df=None):
    """
    Normalize each numeric column of a DataFrame independently.

    :param df: Input DataFrame
    :param method: 'minmax' for Min-Max scaling, 'standard' for Standardization
    :param exclude_columns: List of column names to exclude from normalization
    :param reference_df: Optional DataFrame whose column statistics are used to fit
        each scaler, e.g. the training frame when normalizing a held-out frame.
        Defaults to ``df`` itself, which gives the original fit-on-self behaviour.
        It must contain every column of ``df`` that gets normalized.
    :return: Normalized copy of ``df``; the input frame is not modified
    :raises ValueError: if ``method`` is not 'minmax' or 'standard'
    """
    scaler_classes = {'minmax': MinMaxScaler, 'standard': StandardScaler}

    # Create a copy of the DataFrame to avoid modifying the original
    df_normalized = df.copy()

    if exclude_columns is None:
        exclude_columns = []
    if reference_df is None:
        reference_df = df

    if method not in scaler_classes:
        raise ValueError("Method must be 'minmax' or 'standard'")

    # Only numeric columns of `df` are scaled; excluded names are skipped.
    numeric_columns = df.select_dtypes(include=[np.number]).columns
    columns_to_normalize = [col for col in numeric_columns if col not in exclude_columns]

    for column in columns_to_normalize:
        # Use a fresh scaler per column. Statistics come from `reference_df`, so a
        # held-out frame is transformed with the same parameters as the reference frame.
        scaler = scaler_classes[method]().fit(reference_df[[column]])
        df_normalized[column] = scaler.transform(df_normalized[[column]]).ravel()

    return df_normalized


# Normalize using Standard scaling. Both frames use statistics fitted on the training
# frame. Fitting the test frame on itself would put train and test features on
# different scales, making distances between them meaningless.
# (The `_minmax` suffix is historical: these frames are standard-scaled.)
df_train_minmax = normalize_dataframe_columns(filtered_df_train, method='standard')
df_test_minmax = normalize_dataframe_columns(df_test, method='standard', reference_df=filtered_df_train)


In [None]:
# Convert df row data to row elements in a jax array of type float
numeric_train_data_jax = jnp.array(df_train_minmax.iloc[0:, 1:].to_numpy(dtype=float))
numeric_test_data_jax = jnp.array(df_test_minmax.iloc[0:, 1:].to_numpy(dtype=float))

# Combining image and tabular data

In [None]:
# HDF5 archives whose entries hold the encoded training / test image bytes.
file_path_train, file_path_test = (
    '/kaggle/input/isic-2024-challenge/train-image.hdf5',
    '/kaggle/input/isic-2024-challenge/test-image.hdf5',
)

def hdf5_keys(image_dir):
    """
    List the top-level member names of an HDF5 file.

    :param image_dir: path to the HDF5 file (opened read-only)
    :return: list of member names, in the order h5py iterates them
    """
    with h5py.File(image_dir, 'r') as h5_file:
        return [name for name in h5_file]

# Dataset names in each archive; position i is paired with tabular row i downstream.
train_keys, test_keys = hdf5_keys(file_path_train), hdf5_keys(file_path_test)

In [None]:
# # Load and preprocess images

def load_images(image_dir, keys, numeric_df, batch, num_images=5000, image_size=(128, 128)):
    """
    Load one batch of images from an HDF5 file and append each image's tabular features.

    Each HDF5 entry is decoded with PIL, converted to RGB, resized to ``image_size``
    and flattened. The matching row of ``numeric_df`` is then concatenated onto it.

    :param image_dir: Path to the HDF5 file.
    :param keys: Ordered dataset names inside the file (e.g. from ``hdf5_keys``).
    :param numeric_df: 2-D array of tabular features; row ``i`` is paired with ``keys[i]``.
        NOTE(review): this pairing is positional. It assumes the HDF5 key order matches
        the metadata row order, which is not checked here.
    :param batch: Zero-based batch index. Loads ``keys[batch*num_images:(batch+1)*num_images]``.
    :param num_images: Batch size.
    :param image_size: (width, height) passed to ``PIL.Image.Image.resize``.
    :return: JAX array with one row per image: flattened RGB pixels followed by features.
    :raises IndexError: if ``numeric_df`` has fewer rows than the requested keys.
    :raises ValueError: if the batch is empty (nothing to stack).
    """
    start = batch * num_images
    batch_keys = keys[start:start + num_images]
    # Slice the feature rows once and move them to host memory, instead of indexing the
    # device array once per image.
    batch_features = np.asarray(numeric_df[start:start + len(batch_keys)])
    if len(batch_features) != len(batch_keys):
        raise IndexError(
            f"numeric_df has {len(batch_features)} rows for {len(batch_keys)} keys "
            f"starting at index {start}")

    rows = []
    with h5py.File(image_dir, 'r') as f:
        for name, features in zip(batch_keys, batch_features):
            with Image.open(io.BytesIO(f[name][()])) as img:
                # Force 3 channels so every flattened image has the same length and the
                # rows can be stacked. Grayscale or RGBA inputs would otherwise differ.
                pixels = np.asarray(img.convert('RGB').resize(image_size)).ravel()
            rows.append(np.concatenate([pixels, features]))
    return jnp.stack(rows)

In [None]:
# JAX only model
def create_model(Winit: jax.Array, labelW_init: jax.Array, beta_init=1.):
    """
    Build the parameter dict for the memory model.

    :param Winit: stored patterns; copied into a new JAX array under key 'W'
    :param labelW_init: label weights; copied into a new JAX array under key 'labelW'
    :param beta_init: scalar multiplier for the similarity scores, stored unchanged
    :return: dict with keys 'W', 'labelW', 'beta'
    """
    params = dict(
        W=jnp.array(Winit),
        labelW=jnp.array(labelW_init),
        beta=beta_init,
    )
    return params

def model_call(params, x):
    """
    Compute the energy of the memories for a single (unbatched) query vector.

    Similarity is the negative squared Euclidean distance from ``x`` to each row of
    ``params['W']``. It is scaled by ``params['beta']``, combined with the transposed
    ``params['labelW']``, and reduced with a negated log-sum-exp over axis 1. With the
    square one-hot ``labelW`` used in this notebook, the result has one energy per
    stored row; lower energy means closer.

    :raises AssertionError: if ``x`` has 2 or more dimensions (batch via ``jax.vmap``).
    """
    assert len(x.shape) < 2, "No batch dimension"
    sq_dist = jnp.sum(jnp.square(params['W'] - x), axis=-1)
    logits = params['beta'] * (-sq_dist)[:, None] + jnp.transpose(params['labelW'])
    return -jax.nn.logsumexp(logits, axis=1)

def update_model(params, new_W: jax.Array, new_labelW: jax.Array):
    """
    Return a new parameter dict that holds ``new_W`` / ``new_labelW`` and keeps ``params['beta']``.

    The previous 'W' and 'labelW' are replaced, not extended, so the model only
    remembers the most recently supplied patterns. The new arrays are stored as given,
    without copying, and ``params`` itself is not mutated.
    """
    kept_beta = params['beta']
    return dict(W=new_W, labelW=new_labelW, beta=kept_beta)


def find_similar_samples_with_prob(params, test_images, train_targets, top_k=20):
    """
    Score each test row against the stored memories.

    For every row of ``test_images`` the memory energies are computed with
    ``model_call`` and min-max normalized within that row. The ``top_k`` lowest
    normalized energies are turned into similarities ``exp(-e)``. The reported
    probability is the largest similarity's share of their sum, so it lies in
    ``[1/k, 1]`` with ``k = min(top_k, n_memories)``.

    :param params: model parameter dict ('W', 'labelW', 'beta')
    :param test_images: 2-D array with one query vector per row
    :param train_targets: not read here; kept for interface compatibility
    :param top_k: number of lowest-energy memories used for the probability
        (default 20, the value that was previously hard-coded)
    :return: list with one entry per test row. Each entry is ``[(probability, most_similar_idx)]``,
        where ``most_similar_idx`` is the row of ``params['W']`` with the lowest energy.
    """
    # vmap over the leading (batch) axis; model_call itself only accepts one vector.
    energies = jax.vmap(lambda x: model_call(params, x))(test_images)

    # Vectorized over all rows. The former per-row Python loop had to be unrolled once
    # per test row when traced by jax.jit.
    row_min = jnp.min(energies, axis=1, keepdims=True)
    row_max = jnp.max(energies, axis=1, keepdims=True)
    span = row_max - row_min
    # When every energy in a row is equal (e.g. a single stored memory), the old code
    # divided 0/0 and produced NaN. Treat such a row as all-zero normalized energies.
    span = jnp.where(span > 0, span, 1.0)
    normalized = (energies - row_min) / span

    # Lower energy means higher similarity, so negate before exponentiating.
    nearest = jnp.sort(normalized, axis=1)[:, :top_k]
    similarities = jnp.exp(-nearest)
    probabilities = jnp.max(similarities / jnp.sum(similarities, axis=1, keepdims=True), axis=1)
    most_similar_idxs = jnp.argmin(energies, axis=1)

    # Preserve the original output structure: a one-element list of (prob, idx) per row.
    return [[(probabilities[i], most_similar_idxs[i])] for i in range(energies.shape[0])]


# Compile with XLA via jit. This runs on JAX's default backend (a GPU when one is
# available); it does not by itself force GPU execution.
@jit
def jitted_find_similar_samples_with_prob(params, test_images, train_targets):
    """jit-compiled entry point that forwards its arguments to ``find_similar_samples_with_prob``."""
    result = find_similar_samples_with_prob(params, test_images, train_targets)
    return result

# def find_most_similar_image(params, test_images):
#     # Vectorize the model call over the batch dimension
#     batched_model = jax.vmap(lambda x: model_call(params, x))
    
#     # Compute energies for all test images in parallel
#     energies = batched_model(test_images)
    
#     # Find the most similar index for each test image
#     most_similar_idxs = jnp.argmin(energies, axis=1)
    
#     return most_similar_idxs

# # Ensure the function runs on GPU
# @jax.jit
# def jitted_find_most_similar_image(params, test_images):
#     return find_most_similar_image(params, test_images)

In [None]:
num_train, num_test = len(train_keys), len(test_keys)

batch_size = 5000
# Ceiling division: a trailing partial batch still counts as one batch.
total_train_batches = -(-num_train // batch_size)
total_test_batches = -(-num_test // batch_size)
print(total_test_batches)

In [None]:
# Initialize model
# NOTE(review): `update_model` replaces the stored patterns rather than appending, so
# when this cell finishes the memory holds only the LAST training batch. Earlier
# batches are loaded and then discarded; confirm this is intended.
train_img_stack = load_images(file_path_train, train_keys, numeric_train_data_jax, batch=0,
                              num_images=batch_size)

W_init = train_img_stack
# One one-hot label per stored pattern. Size it from the rows actually loaded so a
# training set smaller than `batch_size` does not give W and labelW mismatched shapes.
labelW_init = jax.nn.one_hot(jnp.arange(len(W_init)), num_classes=len(W_init))
model_params = create_model(W_init, labelW_init, beta_init=1.)


# Load the remaining batches and swap each one into the model in turn.
window_start = time.time()
for train_batch in range(1, total_train_batches):
    train_img_stack = load_images(file_path_train, train_keys, numeric_train_data_jax, train_batch,
                                  num_images=batch_size)
    W = train_img_stack
    Nsamples = len(W)
    labelW = jax.nn.one_hot(jnp.arange(Nsamples), num_classes=Nsamples)

    model_params = update_model(model_params, W, labelW)

    if train_batch % 10 == 0:
        # Report wall time for the last 10 batches. The clock used to be reset every
        # iteration, so the printed figure covered only a single batch.
        elapsed = time.time() - window_start
        print(f'Time for batches {train_batch-9} through {train_batch} of {total_train_batches} is {elapsed}')
        window_start = time.time()


In [None]:
preds = np.zeros(len(test_keys))

# `update_model` replaces the memory with each new batch, so the model holds the
# trailing W.shape[0] training rows. Memory index i therefore refers to training row
# memory_offset + i (offset 0 if the memory ever holds the whole training set).
memory_offset = len(train_keys) - model_params['W'].shape[0]
assert memory_offset >= 0, "model memory has more rows than the training set"

for test_batch in range(total_test_batches):
    test_img_stack = load_images(file_path_test, test_keys, numeric_test_data_jax, test_batch,
                                 num_images=batch_size)
    # Offset by the nominal batch size, not this batch's length. Otherwise a shorter
    # final batch would overwrite earlier predictions instead of filling its own slots.
    batch_start = test_batch * batch_size
    outputs = jitted_find_similar_samples_with_prob(model_params, test_img_stack, train_targets)
    for pred_idx, ((probability, memory_idx),) in enumerate(outputs):
        nearest_target = train_targets[memory_offset + int(memory_idx)]
        # NOTE(review): `probability` is the nearest neighbour's share of similarity among
        # the top-20 (or fewer) matches, so it lies in [1/20, 1]. Mapping a benign match to
        # `probability` and a malignant one to `1 - probability` looks inverted for a
        # malignancy score, but it also drives the ranking. Left unchanged pending confirmation.
        if nearest_target == 0:
            preds[batch_start + pred_idx] = probability
        else:
            preds[batch_start + pred_idx] = 1 - probability

In [None]:
df_sub = pd.read_csv("/kaggle/input/isic-2024-challenge/sample_submission.csv")
# preds[i] is the score for test_keys[i]. Align by id rather than by position so a
# different row order in sample_submission cannot silently misassign scores.
# NOTE(review): assumes the HDF5 dataset names in test_keys are the isic_id values.
pred_by_id = pd.Series(preds, index=test_keys)
df_sub["target"] = df_sub["isic_id"].map(pred_by_id)
assert df_sub["target"].notna().all(), "some submission rows have no prediction"
df_sub.to_csv("submission.csv", index=False)
df_sub