## Metric


In [None]:
import torch

from learners.metrics import MultiIoUMetric


# Smoke-check MultiIoUMetric by feeding it two identical label tensors.
metric = MultiIoUMetric()

first_labels = torch.tensor([1, 0, 1])
second_labels = torch.tensor([1, 0, 1])
metric(first_labels, second_labels)

# Last expression of the cell, so the notebook displays the computed value.
metric.compute()

## Utils


In [None]:
from utils.diff_dict import diff_dict


# Two nested mappings that differ in several ways: keys present on one side
# only ("b" / "c"), a nested key added ("q"), and lists/tuples whose length
# and elements changed.
left = {
    "config": {},
    "a": {"a": {"p": 1}},
    "b": 2,
    "p": (12, 34),
    "d": {"e": 4, "f": 5, "g": [6, 7, 8]},
    "g": [6, 7, 8],
}
right = {
    "config": {},
    "a": {"a": {"p": 1, "q": 0}},
    "c": 3,
    "d": {"e": 4, "f": 5, "g": [6, 9, 8, {"a": 1}]},
    "g": [6, 9],
    "p": (12, 33, 34),
}

# Last expression of the cell, so the notebook displays the result.
diff_dict(left, right)

In [None]:
import os
import datetime


# Print the last-modification time (ISO 8601) of every entry in the SL log dir.
sl_log_dir = "logs/SL"
for entry_name in os.listdir(sl_log_dir):
    entry_path = os.path.join(sl_log_dir, entry_name)
    modified_at = datetime.datetime.fromtimestamp(os.path.getmtime(entry_path))
    print(modified_at.isoformat())

In [None]:
def make_batch_sample_indices(
    population_size: int, sample_size: int, batch_size: int
) -> list[list[int]]:
    """Randomly pick item indices and group them by the batch they fall in.

    ``sample_size`` distinct indices are drawn without replacement from
    ``range(population_size)``. The population is viewed as consecutive
    batches of ``batch_size`` items (the last batch may be shorter), and each
    drawn index is reported as its offset inside its batch.

    Args:
        population_size: Total number of items across all batches.
        sample_size: Number of distinct items to draw.
        batch_size: Number of items per batch.

    Returns:
        One list per batch, in batch order, holding the ascending in-batch
        offsets of the drawn items (empty for batches with no drawn item).
        The outer list has ``ceil(population_size / batch_size)`` entries.

    Raises:
        ValueError: Propagated from ``random.sample`` when ``sample_size`` is
            negative or larger than ``population_size``.
    """
    import random

    samples = sorted(random.sample(range(population_size), sample_size))

    # Ceiling division. ``population_size // batch_size + 1`` would append a
    # spurious empty batch whenever the population divides evenly.
    num_batches = -(-population_size // batch_size)
    batch_samples: list[list[int]] = [[] for _ in range(num_batches)]
    for s in samples:
        batch_index, offset = divmod(s, batch_size)
        batch_samples[batch_index].append(offset)
    return batch_samples


# Example: draw 10 of 100 item indices, grouped into batches of 20 items.
make_batch_sample_indices(100, 10, 20)

In [None]:
def make_batch_sample_indices_multi(
    iterations_batches: list[tuple[int, int]], total_samples: int
) -> list[list[int]]:
    """Split ``total_samples`` across several groups and sample each group.

    Each group is given as ``(iterations, batch_size)`` and contributes a
    population of ``iterations * batch_size`` items. The total is divided
    between groups in proportion to their populations; rounding leftovers are
    settled by adding or removing one sample at a time from a randomly chosen
    group. Each group is then sampled with ``make_batch_sample_indices``.

    Args:
        iterations_batches: ``(iterations, batch_size)`` pairs, one per group.
        total_samples: Number of items to draw over all groups combined.

    Returns:
        The per-batch offset lists returned by ``make_batch_sample_indices``
        for each group, concatenated in group order.

    Raises:
        ValueError: If ``total_samples`` is negative or exceeds the combined
            population of all groups.
    """
    import random

    populations = [iterations * batch for iterations, batch in iterations_batches]
    sum_populations = sum(populations)
    if not 0 <= total_samples <= sum_populations:
        raise ValueError(
            f"total_samples must be between 0 and {sum_populations}, "
            f"got {total_samples}"
        )

    # Proportional share per group. The guard avoids dividing by zero when
    # every population is empty (total_samples is then necessarily 0).
    samples = [
        round(p * total_samples / sum_populations) if sum_populations else 0
        for p in populations
    ]

    # Settle rounding drift one sample at a time, touching only groups that
    # can absorb the change: a count must never drop below 0 or exceed its
    # group's population, otherwise random.sample would reject it later.
    while True:
        shortfall = total_samples - sum(samples)
        if shortfall == 0:
            break
        if shortfall > 0:
            candidates = [
                i for i, (s, p) in enumerate(zip(samples, populations)) if s < p
            ]
            step = 1
        else:
            candidates = [i for i, s in enumerate(samples) if s > 0]
            step = -1
        samples[random.choice(candidates)] += step

    batch_samples = []
    zipped = zip(iterations_batches, populations, samples)
    for (_, batch), population, sample in zipped:
        batch_samples += make_batch_sample_indices(
            population,
            sample,
            batch,
        )

    return batch_samples


# Example: three groups given as (iterations, batch size) pairs, 20 samples in total.
make_batch_sample_indices_multi([(5, 3), (4, 2), (10, 1)], 20)

## WandB


In [None]:
import wandb

from utils.wandb import wandb_login, get_wandb_project

In [None]:
def wandb_log_dataset_ref(dataset_path: str, dataset_name: str, dummy: bool = False):
    """Record a local dataset in W&B as a reference artifact.

    Starts a short helper run tagged ``"helper"``, creates a ``"dataset"``
    artifact that references ``file://{dataset_path}``, logs it, and finishes
    the run.

    Args:
        dataset_path: Local path of the dataset; used to build the ``file://``
            reference URI.
        dataset_name: Name of the artifact; also used in the run name.
        dummy: Forwarded to ``get_wandb_project`` to choose the project.
    """
    wandb_login()
    wandb.init(
        tags=["helper"],
        project=get_wandb_project(dummy),
        name=f"log dataset {dataset_name}",
    )
    try:
        dataset_artifact = wandb.Artifact(dataset_name, type="dataset")
        dataset_artifact.add_reference(f"file://{dataset_path}")
        wandb.log_artifact(dataset_artifact)
    except BaseException:
        # Close the run even on failure so it is not left open in the kernel,
        # and mark it as failed rather than successful.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()


# wandb_log_dataset_ref("D:/Penelitian/FWS/data/REFUGE-train", "REFUGE-train", True)
# wandb_log_dataset_ref("D:/Penelitian/FWS/data/REFUGE-val", "REFUGE-val", True)
# wandb_log_dataset_ref("D:/Penelitian/FWS/data/REFUGE-test", "REFUGE-test", True)

In [None]:
# ac = wandb.Api().artifact_collection(
#     "run_table", "pandegaaz/few-shot-weakly-seg-old/run-svgff5kf-metrics"
# )

# ac.delete()

# for art in ac.artifacts():
#     print(art.name, art.id)

# art: wandb.Artifact = ac.artifacts()[0]  # type: ignore

# print(art.name, art.aliases)

# art.download("ppp/qqq")



In [None]:
# import os
# from utils.wandb import prepare_study_ckpt_artifact_name, prepare_ckpt_artifact_alias, wandb_log_file

# wandb_login()
# wandb.init(
#     tags=["helper"],
#     project=get_wandb_project(False),
#     name="log ckpts",
# )

# ckpts = sorted(filter(lambda x: x.endswith(".ckpt"), os.listdir(os.getcwd())))
# for ckpt in ckpts:
#     study_id = ckpt.split("=")[-1].split(".")[0]
#     artifact_name = prepare_study_ckpt_artifact_name(study_id)
#     new_ckpt = ckpt.split(" study")[0].replace("F=", "fold=") + ".ckpt"
#     artifact_alias = prepare_ckpt_artifact_alias(new_ckpt)
#     artifact_path = os.path.join(os.getcwd(), new_ckpt)
#     os.rename(ckpt, new_ckpt)
#     wandb_log_file(
#         wandb.run,
#         artifact_name,
#         artifact_path,
#         "study-checkpoint",
#         [artifact_alias],
#     )
#     print(ckpt, artifact_name, artifact_alias, new_ckpt, sep=" | ")

# wandb.finish()

In [None]:
# from utils.wandb import wandb_path

# runs = wandb.Api().runs(
#     wandb_path(False),
#     filters={"config.dataset": "RIM-ONE-3-train", "config.study": {"$ne": "8kcKT"}},
# )

# for i, run in enumerate(runs):
#     print(i, run.name)
#     run.config["study"] = "8kcKT"
#     run.update()

# runs = wandb.Api().runs(
#     wandb_path(False),
#     filters={"display_name":"2024-09-13 05-37 eRF"},
# )

# for i, run in enumerate(runs):
#     print(i, run.name, run.group)
#     run.group = "WS multi-step"
#     run.update()

In [None]:
import time
from utils.wandb import wandb_use_alert


def dummy_alert(delay_seconds: float = 30):
    """Raise a deliberate error inside ``wandb_use_alert`` after a delay.

    Helper for manually checking the alert path. NOTE(review): presumably
    ``wandb_use_alert`` reports exceptions raised in its block as a W&B
    alert — confirm against ``utils.wandb``.

    Args:
        delay_seconds: How long to sleep before raising. Defaults to 30;
            pass a smaller value for a quicker check.

    Raises:
        ValueError: Always, with the message ``"Dummy alert"`` (unless the
            context manager suppresses it).
    """
    with wandb_use_alert():
        print("Sleep...")
        time.sleep(delay_seconds)
        raise ValueError("Dummy alert")


# wandb_login()
# wandb.init(project=get_wandb_project(True), name="check alert", group="dummy")
# dummy_alert()

## Aiven


In [None]:
import os
import requests
import time
from typing import Literal

from dotenv import load_dotenv


def turn_aiven_db(
    state: Literal["on", "off"],
    request_timeout: float = 30,
    poll_interval: float = 10,
):
    """Power the Aiven ``optuna-postgres-db`` service on or off.

    Sends a PUT with ``{"powered": <bool>}`` to the service endpoint. When
    turning the service on, the endpoint is then polled until the reported
    service state is ``"RUNNING"``.

    Args:
        state: ``"on"`` to power the service up, ``"off"`` to power it down.
        request_timeout: Timeout in seconds applied to every HTTP request.
        poll_interval: Seconds to sleep between status polls while waiting
            for the service to come up.

    Raises:
        ValueError: If ``state`` is neither ``"on"`` nor ``"off"``, if
            ``AIVEN_API_TOKEN`` is not set, or if the API answers with a
            status code other than 200.
    """
    # Validate first. An unknown state would otherwise be sent as
    # "powered": False and then fall through to the wait-for-RUNNING loop,
    # which could never finish.
    if state not in ("on", "off"):
        raise ValueError(f"state must be 'on' or 'off', got {state!r}")

    url = (
        "https://api.aiven.io/v1/project/few-shot-weakly-seg/service/optuna-postgres-db"
    )
    load_dotenv()
    aiven_token = os.getenv("AIVEN_API_TOKEN")
    if not aiven_token:
        # Without this check the header would literally read "aivenv1 None".
        raise ValueError("AIVEN_API_TOKEN is not set in the environment or .env")
    auth_headers = {"Authorization": f"aivenv1 {aiven_token}"}

    res = requests.put(
        url,
        json={"powered": state == "on"},
        params={"allow_unclean_poweroff": "false"},
        headers=auth_headers,
        timeout=request_timeout,
    )
    if res.status_code != 200:
        raise ValueError(
            f"Failed to turn db {state}, response {res.status_code}: {res.text}"
        )

    if state == "off":
        print("Successfully turned off db")
        return

    print("Waiting for db to turn on")
    # NOTE: there is no upper bound on the total wait; interrupt the cell if
    # the service never reaches RUNNING.
    while True:
        res = requests.get(url, headers=auth_headers, timeout=request_timeout)
        if res.status_code != 200:
            raise ValueError(
                f"Failed to get db status, response {res.status_code}: {res.text}"
            )
        if res.json()["service"]["state"] == "RUNNING":
            break
        time.sleep(poll_interval)
    print("Successfully turned on db")


In [None]:
# turn_aiven_db("off")

## Lightning Issue

Related upstream issue: https://github.com/Lightning-AI/pytorch-lightning/issues/20095


In [None]:
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.tensorboard.writer import SummaryWriter


In [None]:
class SimpleModel(LightningModule):
    """Tiny LightningModule used to exercise model-graph logging.

    The network is a single 3->3 channel 3x3 convolution followed by a
    bilinear resize back to the input's size from dimension 2 onward.
    ``log_graph`` exports the model to ONNX and writes its graph to
    TensorBoard; it is called automatically from ``on_fit_start``.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        """Convolve ``x``, then bilinearly resize the result to ``x.size()[2:]``."""
        out = self.conv(x)
        # NOTE(review): the "main error" marker below presumably flags this call
        # as the trigger of the issue linked in this section — confirm.
        out = nn.functional.interpolate(
            out, x.size()[2:], mode="bilinear"
        )  # main error
        return out

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=0.02."""
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def on_fit_start(self):
        """Log the model graph when fitting starts."""
        super().on_fit_start()
        self.log_graph()

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy loss of the prediction for one ``(x, y)`` batch."""
        x, y = batch
        pred = self(x)
        loss = self.loss(pred, y)
        return loss

    def log_graph(self, inp=None):
        """Export the model to ONNX and write its graph to TensorBoard.

        Args:
            inp: Example input used for tracing. When omitted, a random
                tensor of shape (8, 3, 64, 64) on ``self.device`` is used.
        """
        if inp is None:
            inp = torch.randn(8, 3, 64, 64, device=self.device)

        # Written to "model.onnx" in the current working directory.
        self.to_onnx("model.onnx", inp, export_params=False)

        # Graph logging through the attached Lightning logger, if it is a
        # TensorBoardLogger.
        if isinstance(self.logger, TensorBoardLogger):
            self.logger.log_graph(self, inp)

        # Graph logging through a manually created writer, stored under
        # "tensorboard/manual".
        tensorboard_writer = SummaryWriter("tensorboard/manual")
        tensorboard_writer.add_graph(self, inp)
        tensorboard_writer.close()


# Build the model and a small random dataset: 20 inputs of shape (3, 64, 64)
# with integer targets of shape (64, 64) drawn from {0, 1, 2}.
model = SimpleModel()
train_dataset = TensorDataset(
    torch.randn(20, 3, 64, 64), torch.randint(0, 3, (20, 64, 64))
)
train_loader = DataLoader(train_dataset, batch_size=8)

# model.log_graph(torch.randn(8, 3, 64, 64, device=model.device))

# One-epoch trainer that requests the GPU accelerator and logs to TensorBoard
# under "tensorboard" / "auto" with graph logging enabled; checkpointing is off.
trainer = Trainer(
    deterministic="warn",
    accelerator="gpu",
    max_epochs=1,
    logger=TensorBoardLogger("tensorboard", name="auto", log_graph=True),
    enable_checkpointing=False,
)
# trainer.fit(model, train_loader)

# model.log_graph(torch.randn(8, 3, 64, 64, device=model.device))


## Study Time


In [None]:
def get_study_time(train_size: int, val_size: int, num_folds: int) -> float:
    """Estimate the time for one study run from dataset sizes and fold count.

    The estimate charges 4 per 100 training items and 2 per 100 validation
    items, each multiplied by the number of folds, plus a constant 1.
    NOTE(review): the time unit is not stated in this notebook — confirm.

    Args:
        train_size: Number of training items.
        val_size: Number of validation items.
        num_folds: Number of folds.

    Returns:
        The estimated time as a float.
    """
    train_cost = train_size / 100 * 4 * num_folds
    val_cost = val_size / 100 * 2 * num_folds
    return train_cost + val_cost + 1


# (train_size, val_size, num_folds) for each scenario; the last one combines
# the first three datasets.
study_scenarios = [
    (45, 5, 3),
    (80, 20, 3),
    (320, 80, 3),
    (400, 400, 1),
    (45 + 80 + 320, 5 + 20 + 80, 3),
]

# Print each estimate scaled by 10.
for scenario_train, scenario_val, scenario_folds in study_scenarios:
    print(get_study_time(scenario_train, scenario_val, scenario_folds) * 10)

## Weight Averaging


In [9]:
import torch

from copy import deepcopy

In [10]:
# Load the two fold checkpoints whose network weights will be averaged.
fold0_path = "logs/PS/best study 3UQpU/val_score=0.8294 epoch=11 fold=0 trial=40.ckpt"
fold1_path = "logs/PS/best study 3UQpU/val_score=0.7772 epoch=17 fold=1 trial=40.ckpt"
ckpt0 = torch.load(fold0_path)
ckpt1 = torch.load(fold1_path)

# Start from a deep copy of the fold-0 checkpoint so ckpt0 itself stays intact.
ckpt = deepcopy(ckpt0)

# Element-wise mean of every "net."-prefixed state-dict entry; all other
# entries keep their fold-0 values.
weight = {}
for key, fold0_value in ckpt0["state_dict"].items():
    if not key.startswith("net."):
        continue
    weight[key] = (fold0_value + ckpt1["state_dict"][key]) / 2

ckpt["state_dict"].update(weight)
# torch.save(ckpt, "logs/PS/best study 3UQpU/merged.ckpt")

# Resize Images


In [None]:
import os

from skimage import transform
from skimage import io


In [None]:
def resize_image(img, is_mask: bool, max_size: int = 1024):
    """Shrink an image so that its longer side is at most ``max_size`` pixels.

    Only the first two axes of ``img`` are considered. Images whose longer
    side already fits are returned unchanged (same object). Otherwise the
    longer side becomes exactly ``max_size`` and the other side is scaled
    proportionally, never dropping below one pixel.

    Args:
        img: Array-like image exposing ``shape`` and ``dtype``.
        is_mask: When true, resize with ``order=0`` and no anti-aliasing;
            otherwise use ``order=1`` with anti-aliasing.
        max_size: Maximum allowed length of the longer side. Defaults to 1024.

    Returns:
        The input itself if no resize is needed, else the resized array cast
        back to the input's dtype.

    Raises:
        ValueError: If ``max_size`` is smaller than 1.
    """
    if max_size < 1:
        raise ValueError(f"max_size must be at least 1, got {max_size}")

    dtype = img.dtype
    dim0, dim1 = img.shape[:2]
    longest = max(dim0, dim1)
    if longest <= max_size:
        return img

    # Keep the aspect ratio. The max(1, ...) clamp stops an extreme aspect
    # ratio from rounding the short side down to a zero-sized output.
    if dim0 == longest:
        output_shape = (max_size, max(1, round(max_size * dim1 / dim0)))
    else:
        output_shape = (max(1, round(max_size * dim0 / dim1)), max_size)

    # Masks use order 0 without anti-aliasing; regular images use order 1
    # with anti-aliasing.
    order, anti_aliasing = (0, False) if is_mask else (1, True)

    resized = transform.resize(
        img,
        output_shape,
        order=order,
        preserve_range=True,
        anti_aliasing=anti_aliasing,
    )
    # preserve_range keeps the original value range, so casting back to the
    # input dtype is meaningful.
    return resized.astype(dtype)

In [None]:
# Source directories holding the original ISIC inputs and masks. Each
# "old_<name>" directory is written back to a sibling "<name>" directory,
# which must already exist.
paths = [
    "../data/ISIC2016/old_input",
    "../data/ISIC2016/old_mask",
    "../data/ISIC2017/old_input",
    "../data/ISIC2017/old_mask",
    "../data/ISIC2018/old_input",
    "../data/ISIC2018/old_mask",
]

for path in paths:
    is_mask = "mask" in path
    # Strip the "old_" prefix from the directory only. Applying replace() to
    # the full file path would also rewrite any file name that happens to
    # contain "old_".
    output_dir = path.replace("old_", "")
    for item in os.listdir(path):
        filepath = os.path.join(path, item)
        # Masks are read with as_gray=True, regular inputs as stored.
        img = io.imread(filepath, as_gray=is_mask)
        resized = resize_image(img, is_mask)
        io.imsave(os.path.join(output_dir, item), resized)

# Other
