## Metric


In [None]:
import torch

from learners.metrics import MultiIoUMetric


# Smoke-test MultiIoUMetric: push two identical tensors through the metric's
# forward call, then display whatever compute() returns as the cell output.
metric = MultiIoUMetric()
metric(
    torch.tensor([1, 0, 1]),
    torch.tensor([1, 0, 1]),
)
metric.compute()

## Utils


In [None]:
from utils.diff_dict import diff_dict


# Exercise diff_dict on two nested dicts that cover: an identical empty dict
# ("config"), a key added deep inside nested dicts ("a" -> "a" -> "q"), keys
# present on only one side ("b" vs "c"), list changes nested inside a dict
# ("d" -> "g") and at top level ("g"), and a tuple that gains an element ("p").
# NOTE(review): "p" sits at a different position in the two literals;
# presumably diff_dict ignores dict key order — confirm.
diff_dict(
    {  # first dict
        "config": {},
        "a": {"a": {"p": 1}},
        "b": 2,
        "p": (12, 34),
        "d": {"e": 4, "f": 5, "g": [6, 7, 8]},
        "g": [6, 7, 8],
    },
    {  # second dict
        "config": {},
        "a": {"a": {"p": 1, "q": 0}},
        "c": 3,
        "d": {"e": 4, "f": 5, "g": [6, 9, 8, {"a": 1}]},
        "g": [6, 9],
        "p": (12, 33, 34),
    },
)

In [None]:
import os
import datetime


log_dir = "logs/SL"
for entry in os.listdir(log_dir):
    # Print the entry's last-modification time as a naive local-time ISO string.
    modified_at = datetime.datetime.fromtimestamp(
        os.path.getmtime(os.path.join(log_dir, entry))
    )
    print(modified_at.isoformat())

In [None]:
def make_batch_sample_indices(
    population_size: int, sample_size: int, batch_size: int
) -> list[list[int]]:
    """Randomly pick ``sample_size`` items and group them by batch.

    Items ``0 .. population_size - 1`` are laid out in consecutive batches of
    ``batch_size`` (the last batch may be shorter). ``sample_size`` distinct
    items are drawn without replacement using the global ``random`` module
    state, so ``random.seed`` makes the result reproducible.

    Args:
        population_size: Total number of items.
        sample_size: Number of items to draw.
        batch_size: Number of items per batch; must be positive.

    Returns:
        One list per batch (``ceil(population_size / batch_size)`` lists). Each
        list holds the sorted in-batch positions of the items drawn from that
        batch.

    Raises:
        ValueError: If ``batch_size`` is not positive, or (from
            ``random.sample``) if ``sample_size`` is negative or larger than
            ``population_size``.
    """
    import random

    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Ceil division: a population that is an exact multiple of batch_size must
    # not get an extra, always-empty trailing batch (which would misalign
    # callers that concatenate per-batch lists across segments).
    num_batches = -(-population_size // batch_size)
    batch_samples: list[list[int]] = [[] for _ in range(num_batches)]
    for index in sorted(random.sample(range(population_size), sample_size)):
        batch, position = divmod(index, batch_size)
        batch_samples[batch].append(position)
    return batch_samples


make_batch_sample_indices(100, 10, 20)

In [None]:
def make_batch_sample_indices_multi(
    iterations_batches: list[tuple[int, int]], total_samples: int
) -> list[list[int]]:
    """Spread ``total_samples`` random picks over several consecutive segments.

    Each segment is ``(iterations, batch_size)`` and covers
    ``iterations * batch_size`` items. The samples are apportioned to segments
    in proportion to segment size using the largest-remainder method. The
    per-segment counts therefore always sum to ``total_samples`` and never go
    below 0 or above a segment's population. Each segment's share is then drawn
    with ``make_batch_sample_indices``. The per-batch lists of all segments are
    concatenated in order.

    Args:
        iterations_batches: ``(iterations, batch_size)`` per segment.
        total_samples: Total number of items to pick across all segments.

    Returns:
        One list per batch (``sum(iterations)`` lists in total). Each list
        holds the in-batch positions of the items sampled from that batch.

    Raises:
        ValueError: If ``total_samples`` is negative or exceeds the combined
            population of all segments.
    """
    populations = [iterations * batch for iterations, batch in iterations_batches]
    total_population = sum(populations)
    if not 0 <= total_samples <= total_population:
        raise ValueError(
            f"total_samples must be between 0 and {total_population}, "
            f"got {total_samples}"
        )

    # Largest-remainder apportionment in exact integer arithmetic:
    # - every segment gets the floor of its proportional quota;
    # - the leftover samples go to the segments with the largest fractional
    #   parts (ties: earlier segment first).
    # Unlike "round, then nudge random entries by +/-1", this can never produce
    # a negative count or one above the segment's population.
    samples = [0] * len(populations)
    remainders = [0] * len(populations)
    if total_population:
        for i, population in enumerate(populations):
            samples[i], remainders[i] = divmod(
                population * total_samples, total_population
            )
        leftover = total_samples - sum(samples)
        by_remainder = sorted(range(len(populations)), key=lambda i: -remainders[i])
        for i in by_remainder[:leftover]:
            samples[i] += 1

    batch_samples: list[list[int]] = []
    for (iterations, batch), population, sample in zip(
        iterations_batches, populations, samples
    ):
        # make_batch_sample_indices may return more than `iterations` lists
        # (e.g. a trailing empty batch when population is an exact multiple of
        # batch). Every sampled index is < iterations * batch, so keeping only
        # the first `iterations` lists drops nothing. It also keeps later
        # segments aligned with their iterations.
        batch_samples += make_batch_sample_indices(population, sample, batch)[
            :iterations
        ]
    return batch_samples


make_batch_sample_indices_multi([(5, 3), (4, 2), (10, 1)], 20)

## WandB


In [None]:
import wandb

from config.constants import WANDB_SETTINGS
from utils.wandb import wandb_login

In [None]:
def wandb_log_dataset_ref(dataset_path: str, dataset_name: str, dummy: bool = False):
    """Log a local dataset directory to W&B as a by-reference ``dataset`` artifact.

    Logs in, opens a short-lived run tagged ``"helper"`` in the dummy or real
    project from ``WANDB_SETTINGS``, and logs a ``dataset`` artifact named
    ``dataset_name`` that references ``dataset_path`` through a ``file://`` URI.

    Args:
        dataset_path: Local path of the dataset directory (absolute or relative).
        dataset_name: Artifact name; also used in the run name.
        dummy: Log to ``WANDB_SETTINGS["dummy_project"]`` instead of
            ``WANDB_SETTINGS["project"]``.
    """
    from pathlib import Path

    wandb_login()
    # Build a well-formed file URI. A plain f"file://{path}" is malformed for
    # Windows drive paths ("file://D:/..." parses "D:" as the host) and for
    # relative paths, whereas Path.as_uri() on an absolute path yields e.g.
    # "file:///D:/...".
    dataset_uri = Path(dataset_path).absolute().as_uri()
    # Using the run as a context manager finishes it even if logging the
    # artifact raises, so a failed call does not leave a dangling active run.
    with wandb.init(
        tags=["helper"],
        project=WANDB_SETTINGS["dummy_project" if dummy else "project"],
        name=f"log dataset {dataset_name}",
    ) as run:
        dataset_artifact = wandb.Artifact(dataset_name, type="dataset")
        dataset_artifact.add_reference(dataset_uri)
        run.log_artifact(dataset_artifact)


# NOTE(review): machine-specific absolute path; adjust DATA_ROOT per machine.
DATA_ROOT = "D:/Penelitian/FWS/data"
for split in ("train", "val", "test"):
    wandb_log_dataset_ref(f"{DATA_ROOT}/REFUGE-{split}", f"REFUGE-{split}", dummy=True)

In [None]:
# ac = wandb.Api().artifact_collection(
#     "run_table", "pandegaaz/few-shot-weakly-seg-old/run-svgff5kf-metrics"
# )

# ac.delete()

# for art in ac.artifacts():
#     print(art.name, art.id)

# art: wandb.Artifact = ac.artifacts()[0]  # type: ignore

# print(art.name, art.aliases)

# art.download("ppp/qqq")



## Aiven


In [None]:
import os
import requests
import time
from typing import Literal

from dotenv import load_dotenv


def turn_aiven_db(
    state: Literal["on", "off"],
    *,
    project: str = "few-shot-weakly-seg",
    service: str = "optuna-postgres-db",
    poll_interval: float = 10,
    max_wait: float = float("inf"),
    request_timeout: float = 30,
):
    """Power an Aiven service on or off through the Aiven REST API.

    Loads a ``.env`` file, reads ``AIVEN_API_TOKEN`` and sends a ``PUT`` that
    sets ``powered`` accordingly. When turning on, polls the service until its
    state is ``"RUNNING"``.

    Args:
        state: ``"on"`` or ``"off"``.
        project: Aiven project name.
        service: Aiven service name.
        poll_interval: Seconds to sleep between status checks while waiting.
        max_wait: Maximum seconds to wait for ``"RUNNING"``. The default
            (infinity) waits indefinitely.
        request_timeout: Per-request HTTP timeout in seconds.

    Raises:
        ValueError: If ``state`` is invalid, the token is missing, or an API
            call returns a non-200 status.
        TimeoutError: If the service is not ``"RUNNING"`` within ``max_wait``.
    """
    # The Literal hint is not enforced at runtime. Without this check, any value
    # other than "on" (e.g. "On") would silently power the service OFF.
    if state not in ("on", "off"):
        raise ValueError(f"state must be 'on' or 'off', got {state!r}")

    url = f"https://api.aiven.io/v1/project/{project}/service/{service}"
    load_dotenv()
    aiven_token = os.getenv("AIVEN_API_TOKEN")
    if not aiven_token:
        raise ValueError("AIVEN_API_TOKEN is not set (environment or .env file)")
    auth_headers = {"Authorization": f"aivenv1 {aiven_token}"}

    res = requests.put(
        url,
        json={"powered": state == "on"},
        params={"allow_unclean_poweroff": "false"},
        headers=auth_headers,
        timeout=request_timeout,
    )
    if res.status_code != 200:
        raise ValueError(
            f"Failed to turn db {state}, response {res.status_code}: {res.text}"
        )

    if state == "off":
        print("Successfully turned off db")
        return

    print("Waiting for db to turn on")
    # Monotonic clock so wall-clock adjustments cannot distort the deadline.
    deadline = time.monotonic() + max_wait
    while True:
        res = requests.get(url, headers=auth_headers, timeout=request_timeout)
        if res.status_code != 200:
            raise ValueError(
                f"Failed to get db status, response {res.status_code}: {res.text}"
            )
        if res.json()["service"]["state"] == "RUNNING":
            break
        if time.monotonic() >= deadline:
            raise TimeoutError(f"db did not reach RUNNING within {max_wait} seconds")
        time.sleep(poll_interval)
    print("Successfully turned on db")


In [None]:
# turn_aiven_db("off")

## Lightning Issue

Minimal reproduction for Lightning issue [#20095](https://github.com/Lightning-AI/pytorch-lightning/issues/20095).


In [None]:
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.tensorboard.writer import SummaryWriter


In [None]:
class SimpleModel(LightningModule):
    """Minimal LightningModule used to reproduce pytorch-lightning issue #20095.

    One 3->3 channel conv whose output is bilinearly resized back to the input's
    spatial size. ``log_graph`` exports/logs the model graph three different
    ways: an ONNX file, Lightning's ``TensorBoardLogger``, and a raw
    ``SummaryWriter``. Presumably this is to see which path triggers the error.
    """

    def __init__(self):
        super().__init__()
        # 3 input channels -> 3 output channels; kernel 3 with padding=1 keeps H and W.
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        """Apply the conv, then resize the result to ``x``'s spatial size.

        Assumes ``x`` is (N, C, H, W) — TODO confirm; the default example input
        in ``log_graph`` is (8, 3, 64, 64).
        """
        out = self.conv(x)
        # Size-wise this resize is a no-op (the padded conv already preserves
        # H and W). It is kept because it is the op under investigation.
        out = nn.functional.interpolate(
            out, x.size()[2:], mode="bilinear"
        )  # main error: flagged by the author as the source of the reproduced failure
        return out

    def configure_optimizers(self):
        """Adam over all parameters with a fixed learning rate of 0.02."""
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def on_fit_start(self):
        """Run the parent hook, then log the model graph once at the start of fit()."""
        super().on_fit_start()
        self.log_graph()

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy loss for one ``(x, y)`` batch; nothing is logged."""
        x, y = batch
        pred = self(x)
        loss = self.loss(pred, y)
        return loss

    def log_graph(self, inp=None):
        """Export and log the model graph.

        Args:
            inp: Example input. Defaults to a random (8, 3, 64, 64) tensor on
                the model's device.

        Side effects:
            - writes ``model.onnx`` (graph only, ``export_params=False``);
            - logs the graph through ``self.logger`` when it is a
              ``TensorBoardLogger``;
            - writes a graph event file under ``tensorboard/manual`` with a
              standalone ``SummaryWriter``.
        """
        if inp is None:
            inp = torch.randn(8, 3, 64, 64, device=self.device)

        self.to_onnx("model.onnx", inp, export_params=False)

        # Path 2: Lightning's logger (only when the Trainer uses TensorBoard).
        if isinstance(self.logger, TensorBoardLogger):
            self.logger.log_graph(self, inp)

        # Path 3: bypass Lightning and use TensorBoard's writer directly.
        tensorboard_writer = SummaryWriter("tensorboard/manual")
        tensorboard_writer.add_graph(self, inp)
        tensorboard_writer.close()


# Dummy data: 20 random 3x64x64 inputs with per-pixel labels in {0, 1, 2},
# matching the 3 output channels of SimpleModel's conv.
model = SimpleModel()
train_dataset = TensorDataset(
    torch.randn(20, 3, 64, 64), torch.randint(0, 3, (20, 64, 64))
)
train_loader = DataLoader(train_dataset, batch_size=8)

# Toggle: log the graph manually, before any Trainer is involved.
# model.log_graph(torch.randn(8, 3, 64, 64, device=model.device))

# NOTE(review): accelerator="gpu" requires a GPU device; constructing the
# Trainer presumably fails on CPU-only machines — switch to "auto" there.
trainer = Trainer(
    deterministic="warn",
    accelerator="gpu",
    max_epochs=1,
    logger=TensorBoardLogger("tensorboard", name="auto", log_graph=True),
    enable_checkpointing=False,
)
# Toggle: fit, which calls SimpleModel.on_fit_start -> log_graph.
# trainer.fit(model, train_loader)

# Toggle: log the graph manually after the (optional) fit above.
# model.log_graph(torch.randn(8, 3, 64, 64, device=model.device))


## Other
