<a href="https://colab.research.google.com/github/ramirog034/TQx/blob/main/Deep_Learning_Code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Clone GitHub Repo

In [None]:
# Clone the TQx repository into the current working directory (creates ./TQx).
# NOTE(review): re-running this cell after a successful clone will presumably
# make git complain that the destination already exists — confirm that is acceptable.
!git clone https://github.com/QuIIL/TQx

# Mount Google Drive

In [None]:
# Mount Google Drive at /content/drive; the copy commands in the next cell
# read their source files from /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Ensure a shortcut to the shared Bladder folder exists on your Drive
# (it must be reachable as /content/drive/MyDrive/Bladder).
# Copy the image-feature pickles and the image-info text files for the
# combined, train, valid and test splits into the TQx Bladder results folder.
!cp '/content/drive/MyDrive/Bladder/all_img_features_sorted.pkl' '/content/TQx/results/Bladder'
!cp '/content/drive/MyDrive/Bladder/all_img_features_sorted_test.pkl' '/content/TQx/results/Bladder'
!cp '/content/drive/MyDrive/Bladder/all_img_features_sorted_train.pkl' '/content/TQx/results/Bladder'
!cp '/content/drive/MyDrive/Bladder/all_img_features_sorted_valid.pkl' '/content/TQx/results/Bladder'
!cp '/content/drive/MyDrive/Bladder/img_info_test.txt' '/content/TQx/results/Bladder'
!cp '/content/drive/MyDrive/Bladder/img_info_train.txt' '/content/TQx/results/Bladder'
!cp '/content/drive/MyDrive/Bladder/img_info_valid.txt' '/content/TQx/results/Bladder'

# Imports

In [None]:
import pickle
import numpy as np
import torch
import os
import pandas as pd
import argparse
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import csv
import seaborn as sns
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, cohen_kappa_score, confusion_matrix
from sklearn.manifold import TSNE
from matplotlib.ticker import FuncFormatter
from sklearn.cluster import KMeans
from matplotlib import cm
from sklearn.preprocessing import StandardScaler
from PIL import Image
from IPython import display

# Softmax Function

In [None]:
def softmax(x):
    """Softmax of ``x`` scaled to percentages.

    The global maximum of ``x`` is subtracted before exponentiation for
    numerical stability. Normalisation runs along axis 0, so a 1-D input
    yields values that sum to 100.
    """
    shifted = x - np.max(x)
    exponentials = np.exp(shifted)
    return 100 * exponentials / exponentials.sum(axis=0)

# Mappings

In [None]:
def config_for_dataset(dataset):
    """Return the label look-up tables for a supported dataset.

    Args:
        dataset: Dataset identifier. Accepted values are 'colon-1'/'Colon',
            'luad'/'WSSS4LUAD', 'bach'/'BACH' and 'bladder'/'Bladder'.

    Returns:
        A 4-tuple ``(mapping, mapping_2, mapping_3, custom_order)``:
            mapping      -- full label text -> acronym
            mapping_2    -- full label text -> integer class id
            mapping_3    -- integer class id -> acronym
            custom_order -- acronyms listed in increasing class-id order

    Raises:
        ValueError: if ``dataset`` is not one of the accepted identifiers
            (previously this fell through and failed with an
            UnboundLocalError at the return statement).
    """
    if dataset in ['colon-1', 'Colon']:
        mapping = {'moderately differentiated cancer': 'MD',
                   'poorly differentiated cancer': 'PD',
                   'benign': 'BN',
                   'well differentiated cancer': 'WD'}
        mapping_2 = {'moderately differentiated cancer': 2,
                     'poorly differentiated cancer': 3,
                     'benign': 0,
                     'well differentiated cancer': 1}
        mapping_3 = {2: 'MD',
                     3: 'PD',
                     0: 'BN',
                     1: 'WD'}
        custom_order = ['BN', 'WD', 'MD', 'PD']
    elif dataset in ['luad', 'WSSS4LUAD']:
        mapping = {'tumor': 'TUM',
                   'normal': 'NOR'}
        mapping_2 = {'tumor': 1,
                     'normal': 0}
        mapping_3 = {1: 'TUM',
                     0: 'NOR'}
        custom_order = ['NOR', 'TUM']
    elif dataset in ['bach', 'BACH']:
        mapping = {'invasive carcinoma': 'IVS',
                   'in situ carcinoma': 'SITU',
                   'benign': 'BN',
                   'normal': 'NOR'}
        mapping_2 = {'invasive carcinoma': 3,
                     'in situ carcinoma': 2,
                     'benign': 1,
                     'normal': 0}
        mapping_3 = {1: 'BN',
                     2: 'SITU',
                     3: 'IVS',
                     0: 'NOR'}
        custom_order = ['NOR', 'BN', 'SITU', 'IVS']
    elif dataset in ['bladder', 'Bladder']:
        mapping = {'high grade cancer': 'HIGH',
                   'low grade cancer': 'LOW',
                   'normal': 'NOR'}
        mapping_2 = {'high grade cancer': 2,
                     'low grade cancer': 1,
                     'normal': 0}
        mapping_3 = {2: 'HIGH',
                     1: 'LOW',
                     0: 'NOR'}
        custom_order = ['NOR', 'LOW', 'HIGH']
    else:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of "
            "'colon-1', 'Colon', 'luad', 'WSSS4LUAD', 'bach', 'BACH', "
            "'bladder', 'Bladder'"
        )
    return mapping, mapping_2, mapping_3, custom_order

# Clustering

In [None]:
# Switch into the cloned repository: the functions below open relative paths
# such as 'entity.csv' and 'results/<dataset>/...'.
%cd TQx/

In [None]:
def clustering(args, features, img_label, path, postfix, random_state=None):
    """Cluster feature vectors with k-means, one cluster per dataset class.

    Args:
        args: Namespace providing ``dataset`` (used to look up the class tables).
        features: 2-D tensor or array-like of per-image feature vectors.
        img_label: Iterable of full-text ground-truth labels. They are only
            validated against the dataset configuration, not used for fitting.
        path: Unused; kept for interface compatibility.
        postfix: Unused; kept for interface compatibility.
        random_state: Optional seed forwarded to ``KMeans``. The default
            ``None`` keeps the previous (non-deterministic) behaviour.

    Returns:
        A one-element list holding the array of cluster ids, one per row of
        ``features``.

    Raises:
        KeyError: if ``img_label`` contains a label that the dataset
            configuration does not define.
    """
    # A single look-up instead of one call per table.
    _, str_to_int, int_to_acronym, _ = config_for_dataset(args.dataset)

    # The original code converted every label through the table (and discarded
    # the result); keep the resulting fail-fast check on unknown labels.
    for label_text in img_label:
        if label_text not in str_to_int:
            raise KeyError(label_text)

    # as_tensor + detach avoids the copy and the "copy construct" warning that
    # torch.tensor() emits when it is handed an existing tensor.
    feature_matrix = torch.as_tensor(features).detach().cpu().numpy()

    kmeans = KMeans(n_clusters=len(int_to_acronym), init='k-means++',
                    n_init='auto', random_state=random_state)
    return [kmeans.fit_predict(feature_matrix)]

def scale_list(original_list, scale_factor):
    """Repeat every element ``scale_factor`` times in place, keeping order.

    Example: ``scale_list(['a', 'b'], 2)`` gives ``['a', 'a', 'b', 'b']``.
    """
    return [entry for entry in original_list for _ in range(scale_factor)]

def save_string_to_file(string, file_path):
    """Write ``string`` to ``file_path``, replacing whatever the file held."""
    with open(file_path, mode='w') as sink:
        sink.write(string)

def analyze_clustering(args, raw_df, path, postfix):
    """Plot the class composition of every cluster and dump its top entities.

    Args:
        args: Namespace providing ``top_analyze`` (upper bound on
            ``match_rank``), ``mapping`` (label text -> acronym) and
            ``custom_order`` (acronym order on the x axis).
        raw_df: DataFrame with the columns 'match_rank', 'label', 'entity'
            and a cluster-id column named ``'k = <number of classes>'``.
        path: Directory that receives the figure and the text report.
        postfix: Suffix inserted into the figure file name.

    Side effects:
        Saves ``{path}/k-{k}{postfix}-within.png`` and writes
        ``{path}/common_terms-within.txt``.
        NOTE(review): the text file name carries no postfix, so successive
        calls with the same ``path`` overwrite it — confirm this is intended.
    """
    raw_df = raw_df[raw_df['match_rank'] < args.top_analyze]
    # One row per image (label and cluster id are repeated on every rank row).
    # .copy() makes the column assignment below operate on an independent
    # frame instead of a filtered slice (avoids SettingWithCopyWarning).
    df = raw_df[raw_df['match_rank'] == 1].copy()
    df['label'] = df['label'].map(args.mapping)
    k = df['label'].nunique()
    custom_order = args.custom_order
    cluster_col = f'k = {k}'

    # squeeze=False keeps `axes` an array even for k == 1, where plt.subplots
    # would otherwise return a bare Axes that cannot be indexed.
    fig, axes = plt.subplots(1, k, figsize=(3*k, 2.3), squeeze=False)
    axes = axes.ravel()

    green = '#55a868'
    orange = '#dd8452'
    red = '#c44f52'
    blue = '#4c72b0'

    colors = [green, blue, orange, red]

    report_parts = []
    for cluster in range(k):
        # 20 most frequent entities among all rank rows of this cluster
        entities_counts = (raw_df[raw_df[cluster_col] == cluster]['entity']
                           .value_counts().head(20).to_string())
        report_parts.append(entities_counts + '\n')

        # Fraction of the cluster's images that belong to each class
        label_counts = df[df[cluster_col] == cluster]['label'].value_counts()
        value_counts = label_counts.reindex(custom_order) / label_counts.sum()
        ax = axes[cluster]

        df2 = value_counts.to_frame().reset_index()
        df2.columns = ['Category', 'Count']

        sns.barplot(ax=ax,
                    x='Category',
                    y='Count',
                    data=df2,
                    palette=colors,
                    width=1/2
                    )

        ax.set_title(f'Cluster {cluster+1}')
        ax.title.set_size(17)
        ax.tick_params(axis='x', labelsize=13)
        ax.set_xlabel('')
        ax.set_ylabel('')

    # Show y tick labels with one decimal place
    def format_ticks(x, pos):
        return "{:.1f}".format(x)

    for ax in axes.flat:
        ax.yaxis.set_major_formatter(FuncFormatter(format_ticks))

    plt.tight_layout()
    plt.savefig(f'{path}/k-{k}{postfix}-within.png', dpi=600)
    file_path = f'{path}/common_terms-within.txt'
    save_string_to_file(''.join(report_parts), file_path)

def calculate_sim(args):
    """Match image features to entity text features, then cluster the images.

    Steps performed:
      1. Load entity names/features and the image features/info of the split
         selected by ``args.postfix``.
      2. Optionally keep only entities whose ``semantic_name`` is listed in
         ``args.filter_semantic_name``.
      3. Keep the ``args.top_freq_words`` entities with the best average
         similarity rank over all images.
      4. Build a similarity-weighted text representation per image and pickle
         it to ``{path}/image_text_representation{postfix}.pkl``.
      5. Cluster those representations, write the top matches per image to a
         CSV file and call ``analyze_clustering``.

    Args:
        args: Namespace providing ``dataset``, ``postfix``, ``filters``,
            ``filter_semantic_name``, ``top_freq_words``,
            ``top_freq_features_to_combine`` and ``top_words_in_csv``, plus
            whatever ``analyze_clustering`` reads.

    Returns:
        Array of cluster ids, one per image.

    Raises:
        ValueError: if ``args.dataset`` is not a supported identifier
            (previously an UnboundLocalError on the first use of ``display``).
    """
    # Canonical results-folder name per accepted dataset alias.
    folder_by_dataset = {
        'colon-1': 'Colon', 'Colon': 'Colon',
        'bladder': 'Bladder', 'Bladder': 'Bladder',
        'bach': 'BACH', 'BACH': 'BACH',
        'luad': 'WSSS4LUAD', 'WSSS4LUAD': 'WSSS4LUAD',
    }
    if args.dataset not in folder_by_dataset:
        raise ValueError(f'Unsupported dataset {args.dataset!r}; '
                         f'expected one of {sorted(folder_by_dataset)}')
    display = folder_by_dataset[args.dataset]

    print(f'Clustering {args.dataset}')
    postfix = args.postfix
    top_k = args.top_freq_features_to_combine

    entity_name_path = 'entity.csv'
    entity_feature_path = 'entity_ALL_FEATURES.pkl'
    image_feature_path = f'results/{display}/all_img_features_sorted{postfix}.pkl'
    image_info_path = f'results/{display}/img_info{postfix}.txt'

    path = f'results/{display}/{args.filters}_{args.top_freq_words}_{top_k}'

    entity_name = pd.read_csv(entity_name_path)
    # SECURITY NOTE: pickle.load executes arbitrary code embedded in the file;
    # only load feature files that come from a trusted source.
    with open(entity_feature_path, 'rb') as f:
        entity_feature = pickle.load(f)
    with open(image_feature_path, 'rb') as f:
        image_feature = pickle.load(f)
    with open(image_info_path, 'r') as file:
        img_files = file.readlines()

    # FILTER 1 using semantic name
    if len(args.filter_semantic_name) != 0:
        mask = entity_name['semantic_name'].isin(args.filter_semantic_name)
        indices = entity_name.index[mask]
        entity_name = entity_name.iloc[indices].reset_index(drop=True)
        entity_feature = torch.index_select(entity_feature, 0, torch.tensor(indices))

    # Cosine similarity between every image and every remaining entity
    norm_entity_feature = entity_feature / entity_feature.norm(dim=-1, keepdim=True)
    norm_image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)

    sim = norm_image_feature @ norm_entity_feature.T

    # FILTER 2 using top freq words
    sorted_indices = torch.argsort(sim, dim=1, descending=True)
    # The argsort of a permutation is its inverse, i.e. the 0-based position of
    # every entity in each image's ranking; this replaces the per-row loop.
    ranks = torch.argsort(sorted_indices, dim=1).to(torch.float) + 1
    ranks = torch.sum(ranks, dim=0)/sim.shape[0]  # avg rank of each word
    entity_name['avg_ranking'] = ranks.numpy()
    entity_name = entity_name.nsmallest(args.top_freq_words, 'avg_ranking')
    indices = entity_name.index
    entity_feature = torch.index_select(entity_feature, 0, torch.tensor(indices))
    entity_name = entity_name.reset_index(drop=True)

    norm_entity_feature = entity_feature / entity_feature.norm(dim=-1, keepdim=True)
    sim = norm_image_feature @ norm_entity_feature.T

    # Similarity-weighted average of the entity features for every image
    image_text_representation = torch.matmul(sim, entity_feature)/torch.sum(sim, dim=1)[:, None]

    os.makedirs(path, exist_ok=True)
    with open(f'{path}/image_text_representation{postfix}.pkl', 'wb') as file:
        pickle.dump(image_text_representation, file)

    # top_values: raw cosine similarities; top_values_2: the same entries after
    # a row-wise softmax. Softmax is monotonic, so both share top_indices.
    top_values, top_indices = torch.topk(sim, k=top_k, dim=1)  # shape (num_img, top_entities)
    softmax_output = torch.nn.Softmax(dim=1)(sim)
    top_values_2, top_indices_2 = torch.topk(softmax_output, k=top_k, dim=1)

    if args.dataset == 'gastric':
        split = '.jpg,'
    elif args.dataset in ['luad', 'WSSS4LUAD']:
        split = '.png,'
    else:
        split = ','

    # Every image contributes top_k consecutive rows (one per matched entity).
    img_path = [file.split(split)[0] for file in img_files]
    img_path = scale_list(img_path, top_k)

    # [:-2] drops the last two characters of the label token
    # (presumably a trailing '.' plus the newline — TODO confirm file format).
    raw_label = [file.split(split)[1][:-2] for file in img_files]
    label = scale_list(raw_label, top_k)

    match_rank = list(range(top_k)) * len(img_files)

    top_idx = pd.Series(top_indices.view(-1))
    entity = top_idx.map(entity_name['entity_name'].to_dict())
    entity_semantic = top_idx.map(entity_name['semantic_name'].to_dict())

    # FIX: these two were swapped — the raw cosine scores were written to the
    # 'probability' column and the softmax output to 'cosine_sim'.
    cosine_sim = top_values.view(-1)
    probability = top_values_2.view(-1)

    cluster_labels_list = clustering(args, image_text_representation, raw_label, path, postfix)[0]

    columns = ["image_path", "label", "match_rank", "entity", "entity_semantic", "probability", 'cosine_sim']
    df = pd.DataFrame(columns=columns)
    df['image_path'] = img_path
    df['match_rank'] = match_rank
    df['label'] = label
    df['entity'] = entity
    df['entity_semantic'] = entity_semantic
    df['probability'] = probability
    df['cosine_sim'] = cosine_sim
    # Cluster id of the image, repeated on each of its top_k rows
    df[f"k = {df['label'].nunique()}"] = np.repeat(cluster_labels_list, top_k)
    df_csv = df[df['match_rank'] < args.top_words_in_csv]
    df_csv.to_csv(f'{path}/top_{args.top_words_in_csv}{args.postfix}.csv', index=False)

    analyze_clustering(args, df, path, postfix)
    return cluster_labels_list

def measure_metrics(dataset, pred, label):
    """Compute classification metrics for integer class predictions.

    Args:
        dataset: Dataset identifier; decides which metric set is returned.
            Matching is case-insensitive and also accepts 'Colon', so the
            capitalised names used by ``config_for_dataset`` ('Colon', 'BACH',
            'Bladder') select the grading metrics like their lowercase forms.
        pred: Sequence of predicted integer class ids.
        label: Sequence of ground-truth integer class ids (same length).

    Returns:
        For grading datasets:
            ``(accuracy, grading_accuracy, macro_f1, quadratic_kappa, confusion_matrix)``
            where ``grading_accuracy`` ignores samples whose true class is 0.
        Otherwise:
            ``(accuracy, macro_precision, macro_f1, macro_recall, confusion_matrix)``.
        The confusion matrix is always computed over the classes 0..3.
    """
    grading_datasets = {'colon-1', 'colon-2', 'prostate-1', 'prostate-2', 'prostate-3',
                        'gastric', 'kidney', 'liver', 'bladder', 'bach', 'panda',
                        'colon'}

    acc = accuracy_score(label, pred)
    f1 = f1_score(label, pred, average='macro')
    # Local name chosen so it does not shadow the module-level matplotlib `cm`.
    conf_matrix = confusion_matrix(label, pred, labels=[0, 1, 2, 3])

    if str(dataset).lower() in grading_datasets:
        excluded_class = 0

        # Accuracy restricted to samples whose true class is not class 0
        graded_pairs = [(truth, guess) for truth, guess in zip(label, pred)
                        if truth != excluded_class]
        acc_grading = accuracy_score([truth for truth, _ in graded_pairs],
                                     [guess for _, guess in graded_pairs])

        # Cover every class id present in either vector. The previous
        # arange(len(unique(pred))) dropped classes whenever the model failed
        # to predict one of them, which corrupts the weighted kappa.
        n_classes = int(max(np.max(label), np.max(pred))) + 1
        kappa = cohen_kappa_score(label, pred, labels=np.arange(n_classes), weights='quadratic')
        result = (acc, acc_grading, f1, kappa, conf_matrix)
    else:
        rec = recall_score(label, pred, average='macro')
        pre = precision_score(label, pred, average='macro')
        result = (acc, pre, f1, rec, conf_matrix)

    return result

# Classification

In [None]:
def classify(args):
    """Train an MLP on pre-computed features and log test metrics to a CSV.

    One model is trained per seed in ``args.seed`` with full-batch AdamW. After
    every epoch the model is scored on the validation split; the test metrics
    of the epoch with the best validation macro-F1 are kept for that seed. The
    mean and standard deviation over the seeds are appended as one row to
    ``classification.csv``.

    Args:
        args: Namespace providing ``dataset`` (train/valid source), ``test``
            (test source), ``feature_type`` ('text' or 'image'), ``filters``,
            ``top_freq_words``, ``top_freq_features_to_combine``, ``seed``
            (iterable of ints), ``filter_semantic_name`` and optionally
            ``device`` (CUDA index).

    Raises:
        ValueError: if ``args.feature_type`` is neither 'text' nor 'image'.
    """
    print(args.test, args.feature_type, args.filter_semantic_name)

    class MLP(nn.Module):
        """Linear -> BatchNorm -> ReLU -> Linear classifier."""

        def __init__(self, input_size, hidden_size, output_size):
            super(MLP, self).__init__()
            # Layers are created in the same order as before so that seeded
            # initialisation stays comparable.
            self.fc1 = nn.Linear(input_size, hidden_size)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(hidden_size, output_size)
            self.bnorm = nn.BatchNorm1d(hidden_size)

        def forward(self, x):
            x = self.fc1(x)
            x = self.bnorm(x)
            x = self.relu(x)
            x = self.fc2(x)
            return x

    def delimiter_for(name):
        """Separator between image path and label in a dataset's info file."""
        if name == 'gastric':
            return '.jpg,'
        if name == 'luad' or name == 'WSSS4LUAD':
            return '.png,'
        return ','

    mapping = config_for_dataset(args.dataset)[1]

    def read_labels(info_path, delimiter):
        """Read an info file and return the integer class id of every line."""
        with open(info_path, 'r') as file:
            lines = file.readlines()
        # [:-2] drops the last two characters of the label token
        # (presumably a trailing '.' plus the newline — TODO confirm format).
        return [mapping[line.split(delimiter)[1][:-2]] for line in lines]

    def load_features(feature_path):
        # SECURITY NOTE: pickle.load can execute arbitrary code; only load
        # feature files from a trusted source.
        with open(feature_path, 'rb') as file:
            return pickle.load(file)

    filters = args.filters
    path1 = f'results/{args.dataset}/{filters}_{args.top_freq_words}_{args.top_freq_features_to_combine}'
    path2 = f'results/{args.test}/{filters}_{args.top_freq_words}_{args.top_freq_features_to_combine}'

    if args.feature_type == 'text':
        train_feature_path = f'{path1}/image_text_representation_train.pkl'
        valid_feature_path = f'{path1}/image_text_representation_valid.pkl'
        test_feature_path = f'{path2}/image_text_representation_test.pkl'
    elif args.feature_type == 'image':
        train_feature_path = f'results/{args.dataset}/all_img_features_sorted_train.pkl'
        valid_feature_path = f'results/{args.dataset}/all_img_features_sorted_valid.pkl'
        test_feature_path = f'results/{args.test}/all_img_features_sorted_test.pkl'
    else:
        raise ValueError(f"feature_type must be 'text' or 'image', got {args.feature_type!r}")

    if args.test in ['colon-2', 'prostate-2', 'prostate-3', 'k16']:
        test_feature_path = test_feature_path.replace('_test', '')

    train_split = delimiter_for(args.dataset)
    # The test info file belongs to args.test, so use that dataset's delimiter.
    test_split = delimiter_for(args.test)

    train_feature = load_features(train_feature_path)
    train_label = torch.tensor(read_labels(f'results/{args.dataset}/img_info_train.txt', train_split))

    valid_feature = load_features(valid_feature_path)
    valid_label = read_labels(f'results/{args.dataset}/img_info_valid.txt', train_split)

    test_feature = load_features(test_feature_path)
    test_label = read_labels(f'results/{args.test}/img_info_test.txt', test_split)

    # Define hyperparameters
    input_size = train_feature.shape[1]  # width of the loaded feature vectors
    hidden_size = 2048  # the width the network actually used before
    # One output per class of the training dataset; sizing this from the test
    # labels breaks when the test split lacks a class.
    output_size = len(set(mapping.values()))

    # Honour args.device when that GPU exists; otherwise fall back to the
    # first GPU, or to the CPU when CUDA is unavailable.
    if torch.cuda.is_available():
        gpu_index = getattr(args, 'device', 0)
        if not 0 <= gpu_index < torch.cuda.device_count():
            gpu_index = 0
        device = torch.device(f'cuda:{gpu_index}')
    else:
        device = torch.device('cpu')

    criterion = nn.CrossEntropyLoss()

    # Move the (loop-invariant) data to the device once.
    train_x = train_feature.float().to(device)
    train_y = train_label.to(device)
    valid_x = valid_feature.float().to(device)
    test_x = test_feature.float().to(device)

    # Training loop
    num_epochs = 300

    best_test_metrics = None
    metrics_list = [
        [], [], [], []
    ]
    for seed in args.seed:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        model = MLP(input_size, hidden_size, output_size).to(device)
        optimizer = optim.AdamW(model.parameters(), lr=0.01)
        # Reset per seed so a seed can never inherit another seed's result;
        # -1 guarantees the first epoch is recorded even when F1 is 0.
        best_val_f1 = -1.0
        best_test_metrics = None
        for _ in range(num_epochs):
            model.train()
            optimizer.zero_grad()
            loss = criterion(model(train_x), train_y)
            loss.backward()
            optimizer.step()

            # Evaluate in eval mode without autograd so BatchNorm uses its
            # running statistics instead of the evaluation batch's statistics.
            model.eval()
            with torch.no_grad():
                val_predictions = torch.argmax(model(valid_x), dim=1).tolist()
                metrics = measure_metrics(args.dataset, val_predictions, valid_label)
                if metrics[2] > best_val_f1:
                    best_val_f1 = metrics[2]
                    test_predictions = torch.argmax(model(test_x), dim=1).tolist()
                    best_test_metrics = measure_metrics(args.dataset, test_predictions, test_label)[:-1]
        for i in range(4):
            metrics_list[i].append(best_test_metrics[i])

    new_row = [args.test, filters, args.feature_type,
               f"{np.mean(metrics_list[0]):.4f}+{np.std(metrics_list[0]):.4f}",
               f"{np.mean(metrics_list[1]):.4f}+{np.std(metrics_list[1]):.4f}",
               f"{np.mean(metrics_list[2]):.4f}+{np.std(metrics_list[2]):.4f}",
               f"{np.mean(metrics_list[3]):.4f}+{np.std(metrics_list[3]):.4f}"]

    # Append the summary row to the running results file
    file_path = 'classification.csv'
    with open(file_path, 'a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(new_row)

    # Metrics of the last seed's best epoch
    print(best_test_metrics)

# Cluster Plotting

In [None]:
def plotting_data_split(cluster_labels, args):
    """Scatter-plot t-SNE embeddings of the train/valid/test representations.

    Args:
        cluster_labels: Either an empty dict (colour the points by their
            ground-truth class) or a dict with the keys '_train', '_valid'
            and '_test' holding one cluster id per image of that split
            (colour the points by cluster).
        args: Namespace providing ``dataset``. ``filters``, ``top_freq_words``
            and ``top_freq_features_to_combine`` select the feature directory
            when present; otherwise 'Neoplastic_Process', 1000 and 1000 are
            used, which reproduces the previously hard-coded path.

    Side effects:
        Saves ``tsne_plot_{dataset}_cluster_labels.png`` or
        ``tsne_plot_{dataset}_true_labels.png`` and shows the figure.
    """
    split_names = ['_train', '_valid', '_test']
    use_cluster_labels = len(cluster_labels) != 0

    # Same directory layout that calculate_sim writes to.
    filters = getattr(args, 'filters', 'Neoplastic_Process')
    top_words = getattr(args, 'top_freq_words', 1000)
    top_features = getattr(args, 'top_freq_features_to_combine', 1000)
    feature_dir = f'results/{args.dataset}/{filters}_{top_words}_{top_features}'

    # 2-D t-SNE with a fixed seed and progress logging disabled
    tsne = TSNE(n_components=2,
                random_state=42, verbose=0)

    # NOTE(review): each split is embedded by its own fit_transform call, so
    # the three point clouds do not share one coordinate system — consider
    # fitting t-SNE once on the concatenated features.
    tsne_by_split = []
    for name in split_names:
        # SECURITY NOTE: pickle.load can execute arbitrary code; only load
        # files produced by this pipeline.
        with open(f'{feature_dir}/image_text_representation{name}.pkl', 'rb') as f:
            split_feature = pickle.load(f)
        tsne_by_split.append(tsne.fit_transform(split_feature))

    if use_cluster_labels:
        # FIX: colour by the supplied cluster ids. Previously the point masks
        # still used the ground-truth labels, so the "cluster labels" figure
        # was the ground-truth figure under a different title.
        labels_by_split = [np.asarray(cluster_labels[name]) for name in split_names]

        def legend_name(class_id):
            # k-means ids are arbitrary, so they are not translated to class names
            return f'Cluster {class_id + 1}'
    else:
        # choose proper delimiter corresponding to dataset
        if args.dataset == 'gastric':
            split = '.jpg,'
        elif args.dataset == 'luad' or args.dataset == 'WSSS4LUAD':
            split = '.png,'
        else:
            split = ','

        # [1] is label text -> int, [2] is int -> acronym
        dataset_configs = config_for_dataset(args.dataset)
        str_to_int_map = dataset_configs[1]
        int_to_str_map = dataset_configs[2]

        labels_by_split = []
        for name in split_names:
            with open(f'results/{args.dataset}/img_info{name}.txt', 'r') as file:
                info_lines = file.readlines()
            # second token of each line is the label; [:-2] drops its last two
            # characters (presumably '.' plus newline — TODO confirm format)
            labels_by_split.append(np.asarray(
                [str_to_int_map[line.split(split)[1][:-2]] for line in info_lines]))

        def legend_name(class_id):
            return int_to_str_map[class_id]

    # single Axes on a 10in x 10in figure
    fig, axs = plt.subplots(1, 1, figsize=(10, 10))

    # color scheme
    green = '#55a868'
    orange = '#dd8452'
    red = '#c44f52'
    blue = '#4c72b0'
    colors = [green, orange, blue, red]

    # sorted so the colour assigned to each class is predictable
    unique_class_ids = sorted(set(int(v) for v in np.concatenate(labels_by_split)))

    # one legend entry (handle + text) per class
    legend_handles = []
    legend_labels = []

    for color_idx, class_id in enumerate(unique_class_ids):
        class_label = legend_name(class_id)
        # wrap around instead of raising IndexError with more than four classes
        current_color = colors[color_idx % len(colors)]

        for split_idx, (embedding, split_labels) in enumerate(zip(tsne_by_split, labels_by_split)):
            # rows of this split that belong to the current class
            points = embedding[split_labels == class_id]

            if split_idx == 0:
                # the first split's artist doubles as the legend handle
                sc = axs.scatter(points[:, 0], points[:, 1],
                                 c=current_color, alpha=0.7, label=class_label)
                legend_handles.append(sc)
                legend_labels.append(class_label)
            else:
                axs.scatter(points[:, 0], points[:, 1],
                            c=current_color, alpha=0.7)

    # set title of the image based on which labels we are using
    if use_cluster_labels:
        axs.set_title(f't-SNE of Image-Text Representations for {args.dataset} cluster labels', fontsize=24)
    else:
        axs.set_title(f't-SNE of Image-Text Representations for {args.dataset} true Labels', fontsize=24)

    axs.legend(handles=legend_handles, labels=legend_labels, title='Classes', loc='best')
    # dashed grid lines for visibility
    axs.grid(True, linestyle='--', alpha=0.6)

    # save to png and display plot
    if use_cluster_labels:
        plt.savefig(f'tsne_plot_{args.dataset}_cluster_labels.png', dpi=600)
    else:
        plt.savefig(f'tsne_plot_{args.dataset}_true_labels.png', dpi=600)
    plt.show()

# Main Function

In [None]:
def main(dataset, postfix):
    """Build the run configuration and execute ``calculate_sim`` for one split.

    Args:
        dataset: Default for ``--dataset`` (a name ``config_for_dataset`` accepts).
        postfix: Default for ``--postfix``; the notebook passes '_train',
            '_valid' or '_test'.

    Returns:
        ``(cluster_labels, args)`` — the value returned by ``calculate_sim``
        and the namespace that was used.
    """
    def _str2bool(value):
        # type=bool would turn every non-empty string (including "False")
        # into True, so parse the text explicitly.
        lowered = str(value).strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset', default=dataset)
    parser.add_argument('--test', choices=['colon-1', 'bach', 'bladder', 'luad'], default='')
    parser.add_argument('--feature_type', type=str, default='text')
    # Integer lists: without type/nargs a command-line value arrived as one string.
    parser.add_argument('--seed', type=int, nargs='+', default=list(range(50)))
    parser.add_argument('--plot', choices=['tsne', 'umap'], default='tsne')
    parser.add_argument('--filter_semantic_name', nargs='+', default=['Neoplastic Process;'])
    parser.add_argument('--filters', default='Neoplastic_Process')
    parser.add_argument('--postfix', default=postfix)
    parser.add_argument('--text_split_range', type=int, default=1)
    parser.add_argument('--img_split_range', type=int, default=8192)
    parser.add_argument('--device', type=int, default=2)
    parser.add_argument('--draw', type=_str2bool, default=False)
    parser.add_argument('--num_workers', type=int, default=10)

    parser.add_argument('--top_words_in_csv', type=int, default=20)
    # Compared against the integer match_rank column, so it must be an int.
    parser.add_argument('--top_analyze', type=int, default=50)
    parser.add_argument('--top_freq_features_to_combine', type=int, default=1000)
    parser.add_argument('--top_freq_words', type=int, default=1000)
    parser.add_argument('--n_clusters', type=int, nargs='+', default=[2, 3, 4, 5, 6, 7])
    parser.add_argument('--mode', default='0')

    # Placeholders only: both values are replaced from the dataset
    # configuration right after parsing.
    parser.add_argument('--mapping', default=
                        {'moderately differentiated cancer': 'MD',
                        'poorly differentiated cancer': 'PD',
                        'benign': 'BN',
                        'well differentiated cancer': 'WD'})
    parser.add_argument('--custom_order', default=['BN', 'WD', 'MD', 'PD'])

    # parse_known_args leaves unrecognised command-line arguments in
    # `unknown` instead of exiting with an error.
    args, unknown = parser.parse_known_args()

    label_to_acronym, _, _, acronym_order = config_for_dataset(args.dataset)
    args.mapping = label_to_acronym
    args.custom_order = acronym_order

    # Calculate similarity and return class labels, args
    return calculate_sim(args), args

# T-SNE

In [None]:
# Per dataset: run the pipeline on every split, then draw both t-SNE variants.
cluster_labels = {}
for dataset in ['BACH', 'Colon', 'WSSS4LUAD', 'Bladder']:
  # start from an empty dict so labels never carry over between datasets
  cluster_labels = {}
  for postfix in ['_train', '_valid', '_test']:
    # main() returns the split's cluster ids and the namespace it used
    split_cluster_ids, args = main(dataset, postfix)
    cluster_labels[postfix] = split_cluster_ids

  # all three pkl files exist now: plot once coloured by cluster ...
  plotting_data_split(cluster_labels, args)
  # ... and once by ground truth (an empty dict selects ground-truth labels)
  plotting_data_split({}, args)

# leave no cluster labels behind once every dataset has been processed
cluster_labels = {}

# Dimensionality

Each embedding has 512 dimensions.

BACH train, valid, test counts:
torch.Size([8752, 512]),torch.Size([2674, 512]),torch.Size([2832, 512])


Colon train, valid, test counts:
torch.Size([7027, 512]),torch.Size([1242, 512]),torch.Size([1588, 512])


WSSS4LUAD train, valid, test counts:
torch.Size([10091, 512]),torch.Size([1372, 512]),torch.Size([2063, 512])

Bladder train, valid, test counts:
torch.Size([26450, 512]),torch.Size([12912, 512]),torch.Size([19204, 512])

In [None]:
# test_feature_path = "";
# args = argparse.Namespace()
# for dataset in ['BACH', 'Colon', 'WSSS4LUAD', 'Bladder']:
#   args.dataset = dataset
#   train_feature_path = f'results/{args.dataset}/Neoplastic_Process_1000_1000/image_text_representation_train.pkl'
#   with open(train_feature_path, 'rb') as f:
#       train_feature = pickle.load(f)

#   valid_feature_path = f'results/{args.dataset}/Neoplastic_Process_1000_1000/image_text_representation_valid.pkl'
#   with open(valid_feature_path, 'rb') as f:
#       valid_feature = pickle.load(f)

#   test_feature_path = f'results/{args.dataset}/Neoplastic_Process_1000_1000/image_text_representation_test.pkl'
#   with open(test_feature_path, 'rb') as f:
#       test_feature = pickle.load(f)

#   print(f'{args.dataset} train, valid, test counts:')
#   print(str(train_feature.shape) + "," + str(valid_feature.shape) + "," + str(test_feature.shape))

# Images

In [None]:
# Paths of the t-SNE figures saved by plotting_data_split, one per dataset
ground_truth_label_image_paths = [
    'tsne_plot_BACH_true_labels.png',
    'tsne_plot_Colon_true_labels.png',
    'tsne_plot_WSSS4LUAD_true_labels.png',
    'tsne_plot_Bladder_true_labels.png'
]

cluster_label_image_paths = [
    'tsne_plot_BACH_cluster_labels.png',
    'tsne_plot_Colon_cluster_labels.png',
    'tsne_plot_WSSS4LUAD_cluster_labels.png',
    'tsne_plot_Bladder_cluster_labels.png'
]

# every tile is resized to 6k by 6k pixels
w, h = 6000, 6000

# open the images and resize them to the tile size
gtl_images = [Image.open(p).resize((w, h)) for p in ground_truth_label_image_paths]
cl_images = [Image.open(p).resize((w, h)) for p in cluster_label_image_paths]

# Blank canvas: 2 columns (cluster labels | ground-truth labels), one row per
# dataset. The two constants used to be swapped (NUM_COL = 4, NUM_ROW = 2) and
# only produced the right size because the formula was swapped as well.
NUM_COL = 2
NUM_ROW = 4
grid_w, grid_h = NUM_COL * w, NUM_ROW * h
canvas = Image.new('RGB', (grid_w, grid_h))

# left column: cluster-label plots
for idx, img in enumerate(cl_images):
  x = 0
  y = idx * h
  canvas.paste(img, (x, y))

# right column: ground-truth-label plots
for idx, img in enumerate(gtl_images):
  x = w
  y = idx * h
  canvas.paste(img, (x, y))

canvas.save("grid.png")