# Sparsification


In [None]:
from data.types import SparsityValue
from tasks.optic_disc_cup.datasets import RimOneDataset

# Drawing parameters handed to the dataset for its sparse-annotation generators.
rim_one_sparsity_params: dict = dict(
    point_dot_size=5,
    grid_dot_size=4,
    contour_radius_dist=4,
    contour_radius_thick=2,
    skeleton_radius_thick=4,
    region_compactness=0.5,
)

# Training-mode RIM-ONE dataset: 3 classes, 5 shots, resized to 256x256.
rim_one_data = RimOneDataset(
    mode="train",
    num_classes=3,
    num_shots=5,
    resize_to=(256, 256),
    split_seed=0,
    sparsity_params=rim_one_sparsity_params,
)

# One sparsity value per sparsity mode requested for the sample.
sparsity_values: dict[str, SparsityValue] = dict(
    point=10,
    grid=20,
    contour=1,
    skeleton=1,
    region=1,
    point_old=10,
    grid_old=20,
)

# Fetch sample 0 together with all of its sparse masks.
first_sample = rim_one_data.get_data_with_sparse_all(0, sparsity_values)
image, mask, sparse_masks, image_filename = first_sample

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from skimage import feature, morphology

# Canny edges of the mask, then dilated with a radius-2 disk footprint.
mask_as_float = mask.astype(np.float64)
edge_mask = feature.canny(mask_as_float)
dilation_footprint = morphology.disk(2)
edge_mask_dilated = morphology.binary_dilation(
    edge_mask, footprint=dilation_footprint
)

# Show the distinct values at each stage, then display the dilated edges.
print(*[np.unique(stage) for stage in (mask, edge_mask, edge_mask_dilated)])
plt.imshow(edge_mask_dilated)


# File Processing


## Mixed Print and Log


In [None]:
# log_file_ori = open('outputs/protoseg_long_rimone_to_drishti/learn_log_ori.txt', 'r')
# log_file = open('outputs/protoseg_long_rimone_to_drishti/learn_log.txt', 'w')
#
# for i in range(1, 4900):
#     line = log_file_ori.readline()
#
#     if not line.startswith('INFO'):
#         continue
#
#     log_file.write(line.removeprefix('INFO:root:'))
#
# log_file_ori.close()
# log_file.close()

## Exp Name Update


In [None]:
import json
import os

from config.config_type import AllConfig


def change_exp_name(
    old_name: str,
    new_name: str,
    output_path: str,
    checkpoint_path: str,
    config_filename: str,
) -> None:
    """Rename an experiment on disk and record the new name in its config.

    The directory ``output_path/old_name`` is renamed to
    ``output_path/new_name``; ``checkpoint_path/old_name`` is renamed the same
    way when it exists. If the renamed output directory contains
    ``config_filename``, its ``["learn"]["exp_name"]`` entry is set to
    ``new_name`` and the file is rewritten with 4-space indentation.

    Args:
        old_name: Current experiment directory name.
        new_name: Directory name to rename the experiment to.
        output_path: Folder that holds the experiment output directories.
        checkpoint_path: Folder that holds the experiment checkpoint directories.
        config_filename: Name of the JSON config file inside the output directory.

    Raises:
        FileNotFoundError: If ``output_path/old_name`` does not exist. This is
            checked before anything is renamed, so a mistyped name can no
            longer leave the checkpoint renamed while the outputs are not.
    """
    output_old_path = os.path.join(output_path, old_name)
    output_new_path = os.path.join(output_path, new_name)
    if not os.path.exists(output_old_path):
        raise FileNotFoundError(f"Experiment output not found: {output_old_path}")

    # Checkpoints are optional: an experiment may have outputs only.
    checkpoint_old_path = os.path.join(checkpoint_path, old_name)
    if os.path.exists(checkpoint_old_path):
        os.rename(checkpoint_old_path, os.path.join(checkpoint_path, new_name))

    os.rename(output_old_path, output_new_path)

    config_filepath = os.path.join(output_new_path, config_filename)
    if not os.path.exists(config_filepath):
        return

    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(config_filepath, "r", encoding="utf-8") as config_file:
        config: AllConfig = json.load(config_file)
    config["learn"]["exp_name"] = new_name
    with open(config_filepath, "w", encoding="utf-8") as config_file:
        json.dump(config, config_file, indent=4)


In [None]:
# Rename one experiment in both the outputs and the checkpoint folders.
change_exp_name(
    old_name="v3 RO-DR S WS 2",
    new_name="v3 RO-DR S WS all-iou",
    output_path="outputs",
    checkpoint_path="ckpt",
    config_filename="config.json",
)

In [None]:
# Bulk-rename template: set the prefix that selects experiments and the
# substring pair to substitute. With the empty strings left in place, every
# entry of "outputs" is selected and each name maps to itself.
name_prefix = ""
old_part, new_part = "", ""
for exp_name in os.listdir("outputs"):
    if not exp_name.startswith(name_prefix):
        continue
    renamed = exp_name.replace(old_part, new_part)
    change_exp_name(exp_name, renamed, "outputs", "ckpt", "config.json")

## Prediction Image Rename


In [None]:
import os

# Rename prediction files whose names start with "(" by deleting every
# "(", ")", "," and "'" character from the filename.
paren_quote_chars = str.maketrans("", "", "(),'")

for exp_name in os.listdir("outputs"):
    predictions_dir = os.path.join("outputs", exp_name, "predictions")
    # Entries that are plain files, or experiments without a predictions
    # folder, are skipped instead of aborting the whole pass.
    if not os.path.isdir(predictions_dir):
        continue
    for pred_filename in os.listdir(predictions_dir):
        if not pred_filename.startswith("("):
            continue
        new_pred_filename = pred_filename.translate(paren_quote_chars)
        os.rename(
            os.path.join(predictions_dir, pred_filename),
            os.path.join(predictions_dir, new_pred_filename),
        )

## Add Column to CSV


In [None]:
import os

import pandas as pd


In [None]:
# Sparsity value per sparsity mode, each stored as a one-element list.
sparsity_dict = {
    mode: [value]
    for mode, value in (
        ("point", 10),
        ("grid", 25),
        ("contour", 1),
        ("skeleton", 1),
        ("region", 1),
        ("point_old", 10),
        ("grid_old", 25),
    )
}

In [None]:
# csv_path = 'dummy.csv'

# df = pd.read_csv(csv_path)
# df.insert(1, 'n_shots', df['sparsity_mode'].apply(lambda x: 10))
# df.insert(3, 'sparsity_value', df['sparsity_mode'].apply(lambda x: sparsity_dict[x][0]))
# df.to_csv(csv_path, index=False)

In [None]:
# Add "n_shots" and "sparsity_value" columns to tuned_score.csv of every
# "v3 ... PS ..." experiment whose name does not contain "all".
target_experiments = [
    name
    for name in os.listdir("outputs")
    if " PS " in name and "v3 " in name and "all" not in name
]
for exp_name in target_experiments:
    csv_path = os.path.join("outputs", exp_name, "tuned_score.csv")
    df = pd.read_csv(csv_path)
    # DataFrame.insert raises ValueError when the column already exists, so a
    # file patched by an earlier run is left alone; the cell is safe to re-run.
    if "n_shots" in df.columns or "sparsity_value" in df.columns:
        continue
    df.insert(1, "n_shots", 10)
    df.insert(
        3,
        "sparsity_value",
        df["sparsity_mode"].apply(lambda mode: sparsity_dict[mode][0]),
    )
    df.to_csv(csv_path, index=False)

## Delete Config Param


In [None]:
import os


def delete_config_param(
    output_path: str, exp_path: str, config_filename: str, params: list[str]
) -> None:
    """Remove top-level keys from an experiment's JSON config file.

    Args:
        output_path: Folder that holds the experiment directories.
        exp_path: Experiment directory name inside ``output_path``.
        config_filename: Name of the JSON config file inside the experiment.
        params: Top-level keys to delete; keys that are absent are ignored.

    Returns:
        None. Nothing happens when the config file does not exist, and the
        file is left untouched when none of ``params`` is present in it.
    """
    config_filepath = os.path.join(output_path, exp_path, config_filename)
    if not os.path.exists(config_filepath):
        return

    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(config_filepath, "r", encoding="utf-8") as config_file:
        config = json.load(config_file)

    # Only rewrite (and thereby reformat) the file when something is removed.
    if not any(param in config for param in params):
        return
    for param in params:
        config.pop(param, None)

    with open(config_filepath, "w", encoding="utf-8") as config_file:
        json.dump(config, config_file, indent=4)

In [None]:
# Drop the top-level "weasel" entry from the config of every experiment whose
# directory name contains " PS".
ps_experiments = [name for name in os.listdir("outputs") if " PS" in name]
for exp_name in ps_experiments:
    delete_config_param("outputs", exp_name, "config.json", ["weasel"])

# Logging


## GPU Usage


In [None]:
import torch

# Report total, reserved and allocated memory plus allocator stats for GPU 0.
gpu_index = 0
gpu_properties = torch.cuda.get_device_properties(gpu_index)
print(gpu_properties.total_memory)
for memory_query in (
    torch.cuda.memory_reserved,
    torch.cuda.memory_allocated,
    torch.cuda.memory_stats,
):
    print(memory_query(gpu_index))

In [None]:
# Captured `nvidia-smi` output, kept as a bytes literal so the memory-usage
# parsing can be tried without running the command.
nvidia_smi_text = b"Thu Dec 21 07:17:44 2023       \n+-----------------------------------------------------------------------------+\n| NVIDIA-SMI 450.191.01   Driver Version: 450.191.01   CUDA Version: 11.0     |\n|-------------------------------+----------------------+----------------------+\n| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n|                               |                      |               MIG M. |\n|===============================+======================+======================|\n|   0  Tesla V100-SXM2...  On   | 00000000:8A:00.0 Off |                    0 |\n| N/A   44C    P0   199W / 300W |  10256MiB / 32510MiB |     51%      Default |\n|                               |                      |                  N/A |\n+-------------------------------+----------------------+----------------------+\n                                                                               \n+-----------------------------------------------------------------------------+\n| Processes:                                                                  |\n|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n|        ID   ID                                                   Usage      |\n|=============================================================================|\n+-----------------------------------------------------------------------------+\n"

# extract memory usage and available memory

In [None]:
import re

# command = 'nvidia-smi'
# nvidia_smi_text = subprocess.check_output(command)

# Decode the captured output instead of parsing the repr produced by str(bytes).
if isinstance(nvidia_smi_text, bytes):
    smi_output = nvidia_smi_text.decode("utf-8", errors="replace")
else:
    smi_output = str(nvidia_smi_text)

# Match the "<used>MiB / <total>MiB" pair of the first GPU row. Collecting
# every "...MiB" token broke as soon as a process row (or a second GPU) added
# more of them than the two values being unpacked.
memory_match = re.search(r"(\d+)MiB\s*/\s*(\d+)MiB", smi_output)
if memory_match is None:
    raise ValueError("No '<used>MiB / <total>MiB' entry found in nvidia-smi output")
used_ram, total_ram = (int(amount) for amount in memory_match.groups())
percent_ram = used_ram * 100 / total_ram

percent_ram, total_ram

## Error Logging


In [None]:
import logging

# import sys
# import traceback
import time

In [None]:
# Root logger writes INFO and above to test.log as "<timestamp> | <message>";
# force=True replaces any handlers that were already installed.
log_settings = {
    "filename": "test.log",
    "encoding": "utf-8",
    "level": logging.INFO,
    "format": "%(asctime)s | %(message)s",
    "datefmt": "%Y-%m-%d %H:%M:%S",
    "force": True,
}
logging.basicConfig(**log_settings)

In [None]:
# Log a counter once per second; on any interruption (including
# KeyboardInterrupt, hence BaseException) record the traceback and call stack
# before re-raising.
try:
    for i in range(1000):
        logging.info(i)
        time.sleep(1)
except BaseException:
    # traceback.print_exception(*sys.exc_info())
    # logging.warning(traceback.format_exc())
    # logging.error(traceback.format_exc())
    logging.exception("Exception:", stack_info=True)
    raise
finally:
    logging.info("End")
    logger = logging.getLogger()
    # removeHandler alone does not close the handler, which would leave the
    # log file open; close each handler after detaching it.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

# Main Processing


## Data Loader


In [None]:
from torch.utils.data import DataLoader

from tasks.optic_disc_cup.datasets import RimOneDataset

# Drawing parameters handed to the dataset for its sparse-annotation generators.
rim_one_sparsity_params: dict = dict(
    point_dot_size=5,
    grid_dot_size=4,
    contour_radius_dist=4,
    contour_radius_thick=2,
    skeleton_radius_thick=4,
    region_compactness=0.5,
)

# Training-mode RIM-ONE dataset with "point" sparse masks at sparsity value 20.
rim_one_data = RimOneDataset(
    mode="train",
    num_classes=3,
    num_shots=10,
    resize_to=(256, 256),
    split_seed=0,
    sparsity_mode="point",
    sparsity_value=20,
    sparsity_params=rim_one_sparsity_params,
)

# Unshuffled loader yielding batches of 4 samples using 3 worker processes.
loader_options = dict(batch_size=4, num_workers=3, shuffle=False, pin_memory=False)
rim_one_loader = DataLoader(rim_one_data, **loader_options)


## ProtoSeg Prototypes on Different Batch Size


In [None]:
import time

import numpy as np
import torch
from torch.nn import functional


In [None]:
# Broadcasting check: a (3,) tensor against a (2, 3) tensor ...
torch.tensor([1, 2, 3]) - torch.tensor([[1, 2, 3], [4, 5, 6]])

# ... and the same subtraction with an explicit (1, 3) row.
aa = torch.tensor([1, 2, 3]).unsqueeze(0)
bb = torch.tensor([[1, 2, 3], [4, 5, 6]])
difference = aa - bb
print(aa.shape, bb.shape)
print(difference, difference.shape)

In [None]:
# Random stand-ins for the prototype / embedding / target tensors used by the
# distance-based loss experiments in this section.
proto = torch.randn(30, 3, 4).float()
embed = torch.randn(30, 100, 4).float()
target = torch.randint(0, 3, (30, 100)).long()

In [None]:
# Loss path: squared distances with the prototype axis at dim 1, negated and
# fed to cross_entropy as class scores.
loss_offsets = proto.unsqueeze(2) - embed.unsqueeze(1)
squared_distances_loss = loss_offsets.pow(2).sum(dim=-1)
loss = functional.cross_entropy(-squared_distances_loss, target, ignore_index=-1)

# Prediction path: prototype axis last, nearest prototype picked by argmin.
pred_offsets = proto.unsqueeze(1) - embed.unsqueeze(2)
squared_distances_pred = pred_offsets.pow(2).sum(dim=-1)
pred = squared_distances_pred.argmin(dim=-1)

print(
    squared_distances_loss.shape,
    squared_distances_pred.shape,
    loss,
    pred.shape,
    pred,
    sep="\n",
)

In [None]:
# Compute the loss once per prototype set (broadcast against the whole
# embedding batch) and average the resulting losses.
loss_list = []
for i, proto_row in enumerate(proto.unbind(dim=0)):
    proto_item = proto_row.unsqueeze(0)
    item_offsets = proto_item.unsqueeze(2) - embed.unsqueeze(1)
    squared_distances_loss = item_offsets.pow(2).sum(dim=-1)
    loss = functional.cross_entropy(-squared_distances_loss, target, ignore_index=-1)
    loss_list.append(loss)

print(torch.stack(loss_list).mean())

In [None]:
# Repeat both batches up to the least common multiple of their sizes so the
# prototype and embedding batch dimensions match.
lcm = np.lcm(proto.shape[0], embed.shape[0])
proto_times = lcm // proto.shape[0]
embed_times = lcm // embed.shape[0]
target_times = lcm // target.shape[0]
proto_repeat = proto.repeat(proto_times, 1, 1)
embed_repeat = embed.repeat(embed_times, 1, 1)
target_repeat = target.repeat(target_times, 1)

repeat_offsets = proto_repeat.unsqueeze(2) - embed_repeat.unsqueeze(1)
squared_distances_loss = repeat_offsets.pow(2).sum(dim=-1)
loss = functional.cross_entropy(-squared_distances_loss, target_repeat, ignore_index=-1)
print(loss)

In [None]:
# Align the embedding/target batch with the prototype batch by cycling through
# the embedding batch indices: a larger embedding batch is truncated, a smaller
# one wraps around as often as needed. Padding with embed[:batch_diff] only
# worked while the prototype batch was at most twice the embedding batch.
batch_index = torch.arange(proto.shape[0], device=embed.device) % embed.shape[0]
new_embed = embed[batch_index]
new_target = target[batch_index.to(target.device)]

squared_distances_loss = torch.sum(
    (proto.unsqueeze(2) - new_embed.unsqueeze(1)) ** 2, dim=-1
)
loss = functional.cross_entropy(-squared_distances_loss, new_target, ignore_index=-1)
print(loss)

In [None]:
# Compare four ways of pairing prototypes with embeddings over 1000 random
# draws: (1) direct broadcast, (2) one prototype set at a time, (3) repeat to
# the least common multiple, (4) align the embedding batch to the prototypes.
# Durations use time.perf_counter, the clock intended for short intervals.
loss_list_1, loss_list_2, loss_list_3, loss_list_4 = [], [], [], []
total_time_1 = total_time_2 = total_time_3 = total_time_4 = 0
for i in range(1000):
    proto = torch.randn([30, 3, 4]).type(torch.FloatTensor)
    embed = torch.randn([30, 100, 4]).type(torch.FloatTensor)
    target = torch.randint(0, 3, [30, 100]).type(torch.LongTensor)

    # Strategy 1: direct broadcast.
    start_1 = time.perf_counter()
    squared_distances_loss = torch.sum(
        (proto.unsqueeze(2) - embed.unsqueeze(1)) ** 2, dim=-1
    )
    loss = functional.cross_entropy(-squared_distances_loss, target, ignore_index=-1)
    loss_list_1.append(loss)
    total_time_1 += time.perf_counter() - start_1

    # Strategy 2: one prototype set at a time, losses averaged.
    start_2 = time.perf_counter()
    loss_list_inner = []
    for j in range(proto.shape[0]):
        proto_item = proto[j : j + 1, :, :]
        squared_distances_loss = torch.sum(
            (proto_item.unsqueeze(2) - embed.unsqueeze(1)) ** 2, dim=-1
        )
        loss = functional.cross_entropy(
            -squared_distances_loss, target, ignore_index=-1
        )
        loss_list_inner.append(loss)
    loss_list_2.append(torch.mean(torch.stack(loss_list_inner)))
    total_time_2 += time.perf_counter() - start_2

    # Strategy 3: repeat both batches up to their least common multiple.
    start_3 = time.perf_counter()
    lcm = np.lcm(proto.shape[0], embed.shape[0])
    proto_repeat = proto.repeat(lcm // proto.shape[0], 1, 1)
    embed_repeat = embed.repeat(lcm // embed.shape[0], 1, 1)
    target_repeat = target.repeat(lcm // target.shape[0], 1)
    squared_distances_loss = torch.sum(
        (proto_repeat.unsqueeze(2) - embed_repeat.unsqueeze(1)) ** 2, dim=-1
    )
    loss = functional.cross_entropy(
        -squared_distances_loss, target_repeat, ignore_index=-1
    )
    loss_list_3.append(loss)
    total_time_3 += time.perf_counter() - start_3

    # Strategy 4: cycle through the embedding batch so it matches the
    # prototype batch size (truncates when larger, wraps around when smaller).
    start_4 = time.perf_counter()
    batch_index = torch.arange(proto.shape[0], device=embed.device) % embed.shape[0]
    new_embed = embed[batch_index]
    new_target = target[batch_index.to(target.device)]
    squared_distances_loss = torch.sum(
        (proto.unsqueeze(2) - new_embed.unsqueeze(1)) ** 2, dim=-1
    )
    loss = functional.cross_entropy(
        -squared_distances_loss, new_target, ignore_index=-1
    )
    loss_list_4.append(loss)
    total_time_4 += time.perf_counter() - start_4

# Mean and standard deviation of the loss per strategy, then total durations.
print(torch.mean(torch.stack(loss_list_1)), torch.std(torch.stack(loss_list_1)))
print(torch.mean(torch.stack(loss_list_2)), torch.std(torch.stack(loss_list_2)))
print(torch.mean(torch.stack(loss_list_3)), torch.std(torch.stack(loss_list_3)))
print(torch.mean(torch.stack(loss_list_4)), torch.std(torch.stack(loss_list_4)))
print(total_time_1, total_time_2, total_time_3, total_time_4)

In [None]:
# Not working: using the mean gives a different distribution.

# Average the embeddings over the batch axis and repeat the result 30 times.
embed_mean = embed.mean(dim=0, keepdim=True)
new_embed = embed_mean.repeat(30, 1, 1)

mean_offsets = proto.unsqueeze(2) - new_embed.unsqueeze(1)
squared_distances_loss = mean_offsets.pow(2).sum(dim=-1)
loss = functional.cross_entropy(-squared_distances_loss, target, ignore_index=-1)
loss

## Loss Functions


In [None]:
import time

import torch

from models.u_net import UNet
from tasks.optic_disc_cup.losses import DiscCupLoss


In [None]:
# Try the "ce" loss on random class-index masks and random predictions.
ce_loss = DiscCupLoss("ce")
random_msk = torch.randint(0, 3, (4, 256, 256)).long()
random_pred = torch.randn(4, 3, 256, 256).float()

mask_value_counts = torch.unique(random_msk, return_counts=True)
print(
    random_msk.shape,
    random_pred.shape,
    mask_value_counts,
    random_pred.max(),
    random_pred.min(),
)

ce_loss(random_pred, random_msk)

In [None]:
# Pull the first batch from the RIM-ONE loader.
rim_one_iterator = iter(rim_one_loader)
first_batch = next(rim_one_iterator)
img, msk, sparse_msk, img_filename = first_batch

In [None]:
loss, loss_2 = 0.0, 0.0

# Prediction of a freshly constructed network ...
net = UNet(3, 3)
pred = net(img)

# ... and of the same network after loading checkpoint weights.
# map_location="cpu" lets a checkpoint saved from a GPU run load on a machine
# without CUDA; load_state_dict then copies the values into net's parameters.
net.load_state_dict(
    torch.load("ckpt/v3 RO-DR S PS all/net.pth", map_location="cpu")
)
pred_2 = net(img)

In [None]:
# Summarise the batch and both predictions: dtypes, shapes, label counts and
# the value range of each prediction.
print(img_filename)
batch_tensors = (img, msk, sparse_msk, pred, pred_2)
print(*[tensor.dtype for tensor in batch_tensors])
print(*[tensor.shape for tensor in batch_tensors])
msk_counts = torch.unique(msk, return_counts=True)
sparse_msk_counts = torch.unique(sparse_msk, return_counts=True)
print(msk_counts, sparse_msk_counts)
for output in (pred, pred_2):
    print(output.max(), output.min())

In [None]:
# Time building the "ce" loss and evaluating it on both predictions.
start_time = time.time()
for i in range(1):
    ce_loss = DiscCupLoss("ce")
    loss, loss_2 = ce_loss(pred, msk), ce_loss(pred_2, msk)
elapsed = time.time() - start_time
print(elapsed)
print(loss, loss_2)

In [None]:
# Time building the "bce" loss and evaluating it on both predictions.
start_time = time.time()
for i in range(1):
    bce_loss = DiscCupLoss("bce")
    loss, loss_2 = bce_loss(pred, msk), bce_loss(pred_2, msk)
elapsed = time.time() - start_time
print(elapsed)
print(loss, loss_2)

In [None]:
# Time building the "iou" loss and evaluating it on both predictions.
start_time = time.time()
for i in range(1):
    iou_loss = DiscCupLoss("iou")
    loss, loss_2 = iou_loss(pred, msk), iou_loss(pred_2, msk)
elapsed = time.time() - start_time
print(elapsed)
print(loss, loss_2)

In [None]:
# Time building the "iou_bce" loss and evaluating it on both predictions.
start_time = time.time()
for i in range(1):
    iou_bce_loss = DiscCupLoss("iou_bce")
    loss, loss_2 = iou_bce_loss(pred, msk), iou_bce_loss(pred_2, msk)
elapsed = time.time() - start_time
print(elapsed)
print(loss, loss_2)

## GuidedNet Prototypes with More Efficient Calc


In [None]:
import torch
from torch import nn

from models.u_net import UNet


In [None]:
# Pull the first batch from the loader and move its tensors to the GPU.
rim_one_iterator = iter(rim_one_loader)
img, msk, sparse_msk, img_filename = next(rim_one_iterator)
img, msk, sparse_msk = (tensor.cuda() for tensor in (img, msk, sparse_msk))

In [None]:
# Stack each tensor on top of itself to double the batch.
m_img, m_sparse_msk = (
    torch.vstack([tensor, tensor]) for tensor in (img, sparse_msk)
)

In [None]:
# Two prototype-mode UNets: one fed the 3-channel images, one fed 1-channel
# masks.
net_image = UNet(3, 8, prototype=True).cuda()
net_mask = UNet(1, 8, prototype=True).cuda()

# Head: 64 -> 32 channels (3x3 conv), ReLU, then 32 -> 2 channels (1x1 conv).
head_width = 32
net_head = nn.Sequential(
    nn.Conv2d(2 * head_width, head_width, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(head_width, 2, kernel_size=1),
).cuda()

# Set the weights of both convolutions to ones.
for head_conv in (net_head[0], net_head[-1]):
    nn.init.ones_(head_conv.weight)

# Global average pooling down to a 1x1 spatial map.
net_merge = nn.AdaptiveAvgPool2d((1, 1)).cuda()

In [None]:
def one_hot_masks(ori_mask, num_classes):
    """Split a class-index mask into one float mask per class.

    For each class ``c`` in ``range(num_classes)`` the result holds a float
    tensor that is 1 where ``ori_mask == c`` and 0 elsewhere, with a new axis
    of size 1 inserted at dim 1. Values outside ``range(num_classes)`` are 0 in
    every returned mask.
    """
    return [
        (ori_mask == class_index).float().unsqueeze(1)
        for class_index in range(num_classes)
    ]

In [None]:
# Per-class masks of the sparse annotation, the image embedding, and one mask
# embedding per class mask.
m_sparse_msk_list = one_hot_masks(m_sparse_msk, 3)
m_img_embed = net_image(m_img)
m_sparse_msk_embed_list = list(map(net_mask, m_sparse_msk_list))

In [None]:
# Shapes and dtypes of the raw batch, the doubled batch and the embeddings.
raw_batch = (img, msk, sparse_msk)
first_class_mask = m_sparse_msk_list[0]
first_mask_embed = m_sparse_msk_embed_list[0]
print(*[tensor.shape for tensor in raw_batch])
print(*[tensor.dtype for tensor in raw_batch])
print(m_img.shape, m_sparse_msk.shape, first_class_mask.shape)
print(m_img_embed.shape, first_mask_embed.shape)
print(m_img_embed.dtype, first_mask_embed.dtype)

In [None]:
# Toggle: weight the image embedding by the raw masks (True) or by the mask
# network's embeddings (False).
no_net_mask = True

if no_net_mask:
    # Sum the per-class masks and use the result as a multiplicative weight.
    # The loop variable is deliberately not called `mask`, so it no longer
    # overwrites the notebook-level `mask` loaded in the Sparsification section.
    combined_mask = torch.zeros_like(m_sparse_msk_list[0])
    for class_mask in m_sparse_msk_list:
        combined_mask += class_mask
    combined_embeddings = m_img_embed * combined_mask
else:
    # Multiply the image embedding by every mask embedding in turn; the clone
    # keeps m_img_embed itself unmodified by the in-place products.
    combined_embeddings = torch.clone(m_img_embed)
    for mask_embedding in m_sparse_msk_embed_list:
        combined_embeddings *= mask_embedding

# Pool each sample to 1x1, average over the batch, then tile to 256x256.
merged_embeddings = net_merge(combined_embeddings)
prototypes = torch.mean(merged_embeddings, dim=0, keepdim=True)
tiled_prototypes = torch.tile(prototypes, (1, 1, 256, 256))

print(
    combined_embeddings.shape,
    merged_embeddings.shape,
    prototypes.shape,
    tiled_prototypes.shape,
)
print(
    combined_embeddings.dtype,
    merged_embeddings.dtype,
    prototypes.dtype,
    tiled_prototypes.dtype,
)

# Other
