In [1]:
import CONST
import pandas as pd
import os
import torch
import csv
import glob
import numpy as np
from sklearn.decomposition import PCA
from PIL import Image
from torch.utils.data import Dataset, DataLoader

In [2]:
# Load all embeddings.
# os.listdir returns entries in arbitrary order, so the names are sorted to keep
# the row order of the stacked matrix (and of pca_result) reproducible.
embedding_files = sorted(
    filename
    for filename in os.listdir(CONST.PROCESSED_EMBEDDING_DIR)
    if filename.endswith(".pt")  # Assuming the tensors have a .pt extension
)
if not embedding_files:
    # Fail with a message that names the directory instead of the opaque
    # error torch.stack raises on an empty list.
    raise FileNotFoundError(
        f"No .pt embeddings found in {CONST.PROCESSED_EMBEDDING_DIR}"
    )

tensors = [
    torch.load(os.path.join(CONST.PROCESSED_EMBEDDING_DIR, filename))
    for filename in embedding_files
]

# Stack the tensors into a 2D tensor
tensor_stack = torch.stack(tensors)

# Perform PCA to reduce to 3 dimensions
pca = PCA(n_components=3)
pca_result = pca.fit_transform(tensor_stack)

In [8]:
# Create a dataset class
class GeneratedData(Dataset):
    """Dataset over the generated images found in ``CONST.FINAL_GENERATION_OUTPUT``.

    Indexing returns a tuple ``(image, textual_embedding, numeric_embedding)``:

    * ``image`` -- path of the generated image file (the file is not opened here).
    * ``textual_embedding`` -- the ``str()`` of the value arrays of the ruler,
      type, bed, depth and location annotation columns, joined by single spaces.
    * ``numeric_embedding`` -- the object ``torch.load`` returns for the
      embedding file in ``CONST.PROCESSED_EMBEDDING_DIR`` that belongs to the image.
    """

    # Glob patterns (appended to CONST.FINAL_GENERATION_OUTPUT) that select images.
    _IMAGE_PATTERNS = ("*.jpeg", "*.JPG", "*.jpg", "*.png")

    def __init__(self):
        """Collect the image paths and read the processed annotation file."""
        found = []
        for pattern in self._IMAGE_PATTERNS:
            found.extend(glob.glob(f"{CONST.FINAL_GENERATION_OUTPUT}{pattern}"))
        # Where glob matching is case-insensitive (e.g. Windows) "*.JPG" and
        # "*.jpg" return the same files; drop duplicates and sort so the item
        # order is stable between runs.
        self.data = sorted(set(found))

        # Read the annotation file
        self.annotation = pd.read_csv(
            CONST.ANNOTATION_PROCESSED_PATH, dtype={CONST.WOUND_RULER: str}
        )
        # Extension-less base name of every annotated file, computed once so
        # each lookup in __getitem__ can compare names exactly.
        self._annotation_stems = self.annotation[CONST.FILE_NAME].map(
            lambda name: os.path.splitext(os.path.basename(str(name)))[0]
        )

    def __getitem__(self, index):
        """Return ``(image path, textual embedding, numeric embedding)`` for ``index``.

        Raises:
            FileNotFoundError: if no embedding file matches the image name.
        """
        # Image
        image = self.data[index]

        # File name
        file_name = os.path.splitext(os.path.basename(image))[0]

        # Textual embeddings (annotations).
        # An exact base-name match is preferred so that e.g. "abc" does not
        # also pick up the rows of "abc-2"; the substring search is kept as a
        # fallback for names that only partially match.
        row = self.annotation[self._annotation_stems == file_name]
        if row.empty:
            row = self.annotation[
                self.annotation[CONST.FILE_NAME].str.contains(
                    file_name, regex=False, na=False
                )
            ]
        annotation_columns = (
            CONST.WOUND_RULER,
            CONST.WOUND_TYPE,
            CONST.WOUND_BED,
            CONST.WOUND_DEPTH,
            CONST.WOUND_LOCATION,
        )
        textual_embedding = " ".join(
            str(row[column].values) for column in annotation_columns
        )

        # Numeric embeddings.
        # glob.escape keeps "[", "]", "*" and "?" in the file name literal.
        candidates = sorted(
            glob.glob(f"{CONST.PROCESSED_EMBEDDING_DIR}{glob.escape(file_name)}*")
        )
        if not candidates:
            # An IndexError raised from __getitem__ can be mistaken for
            # "end of sequence" by plain iteration, so report the real cause.
            raise FileNotFoundError(
                f"No embedding found for {file_name!r} in "
                f"{CONST.PROCESSED_EMBEDDING_DIR}"
            )
        exact = [
            path
            for path in candidates
            if os.path.splitext(os.path.basename(path))[0] == file_name
        ]
        numeric_embedding = torch.load(exact[0] if exact else candidates[0])

        return image, textual_embedding, numeric_embedding

    def __len__(self):
        """Number of generated images that were found."""
        return len(self.data)

# One sample per batch, served from the generated-image dataset
generated_loader = DataLoader(GeneratedData(), batch_size=1)

In [4]:
# Smallest square grid (in tiles per side) with at least one cell per loader batch
sprite_edge_count = int(np.ceil(np.sqrt(len(generated_loader))))

# Each side of the canvas is sprite_edge_count tiles of 100 pixels
w_sprite = h_sprite = sprite_edge_count * 100

# Fully transparent RGBA canvas; pasting starts at the top-left corner
sprite = Image.new("RGBA", (w_sprite, h_sprite), color=(0, 0, 0, 0))
x_offset, y_offset = 0, 0

In [11]:
# Build the sprite sheet and write one metadata line and one embedding row per
# sample in a single pass over the loader.
# The paste offsets are reset and both output files are opened once in write
# mode, so re-running this cell regenerates the outputs instead of appending
# duplicate rows that would no longer line up with the sprite tiles.
x_offset = 0
y_offset = 0

# newline="" is what the csv module requires of the file object it writes to.
with open(CONST.FINAL_METADATA, "w") as text_emb, open(
    CONST.FINAL_EMBEDDINGS, "w", newline=""
) as num_emb:
    csv_writer = csv.writer(num_emb, delimiter="\t")

    for image_path, textual_embedding, numeric_embedding in generated_loader:
        # Process the sprite: resize() returns a new image, so the source file
        # handle can be closed as soon as the thumbnail exists.
        with Image.open(image_path[0]) as source_image:
            tile = source_image.resize((100, 100))
        # Wrap to the start of the next row when the tile would overflow
        if x_offset + tile.width > w_sprite:
            x_offset = 0
            y_offset += tile.height
        sprite.paste(tile, (x_offset, y_offset))
        x_offset += tile.width

        # Index 0 selects the single sample of the batch
        text_emb.write(f"{textual_embedding[0]}\n")
        csv_writer.writerow(numeric_embedding[0].squeeze().numpy())

print("writing complete")
sprite.convert("RGB").save(CONST.FINAL_IMAGE_SPRITE, transparency=0)

writing complete


In [10]:
# Sanity check: project a single stored embedding through the fitted PCA
sample_embedding = torch.load(
    "../resources/processed/embeddings/002-4228Z-2018-10-09-9a.pt"
)
# unsqueeze(0) adds a leading axis so the embedding is passed as one row
sample_projection = pca.transform(sample_embedding.unsqueeze(0))
print(sample_projection.flatten())

[-5.70845696  1.81629995  5.26043689]


In [None]:
# NOTE(review): this cell reads `self.annotation` and appends to `self.data` at
# the top level of the cell, where no `self` is defined in the visible notebook.
# It looks like the body of a dataset method kept as a draft, and the cell has
# no execution count -- confirm where it belongs before running it.
#
# For every annotation row: build a labelled text description, project the
# stored embedding through `pca`, open the generated image, and append
# [text, projected embedding, image] to `self.data`.
for idx, row in self.annotation.iterrows():
    # Get file base name (the extension is discarded)
    file_name, _ = os.path.splitext(row[CONST.FILE_NAME])

    # Process textual embeddings: "LABEL: value" pairs joined by single spaces
    textual_embedding = " ".join(
        [
            "RULER:",
            str(row[CONST.WOUND_RULER]),
            "TYPE:",
            str(row[CONST.WOUND_TYPE]),
            "WOUND_BED:",
            str(row[CONST.WOUND_BED]),
            "WOUND_DEPTH:",
            str(row[CONST.WOUND_DEPTH]),
            "LOCATION:",
            str(row[CONST.WOUND_LOCATION]),
        ]
    )

    # Load the stored "<file_name>.pt" embedding and add a leading axis so it
    # is passed to pca.transform as a single row; the result is then flattened
    embedding = torch.load(
        os.path.join(CONST.PROCESSED_EMBEDDING_DIR, file_name + ".pt")
    ).unsqueeze(0)
    embedding = pca.transform(embedding).flatten()

    # Open the generated image; only the "<file_name>.png" name is tried here.
    # NOTE(review): the opened image is stored as-is and never closed in this
    # loop -- confirm this is acceptable for the number of annotation rows.
    image = Image.open(os.path.join(CONST.FINAL_GENERATION_OUTPUT, file_name + ".png"))
    self.data.append([textual_embedding, embedding, image])