In [None]:
# HCA97 source code with slight modifications for openmax
from pytorch_lightning.loggers import WandbLogger
import wandb

# Authenticate with Weights & Biases before any run is created.
wandb.login()

# Settings for this run's logger: W&B project, run name, and the
# checkpoint-logging mode handed to WandbLogger.
wandb_run_settings = dict(
    project='CLIP',
    name='CLIP_anno2_open',
    log_model='all',
)
wandb_logger = WandbLogger(**wandb_run_settings)

In [None]:
from typing import List

import torch as th
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
# from sklearn.model_selection import train_test_split
import pandas as pd

import src.classification as lc

from pytorch_lightning.callbacks import ModelCheckpoint, Callback, EarlyStopping

from src.experiments import ExperimentMosquitoClassifier

def callbacks() -> List[Callback]:
    """Build the Trainer callbacks: best-F1 checkpointing plus early stopping.

    Returns:
        A two-element list ``[ModelCheckpoint, EarlyStopping]``; both watch
        ``val_f1_score`` and treat larger values as better.
    """
    tracked_metric = "val_f1_score"

    # Keep the two best weight-only checkpoints by validation F1; no "last"
    # checkpoint is written. `dirpath` is left unset.
    checkpoint_cb = ModelCheckpoint(
        monitor=tracked_metric,
        mode="max",
        save_top_k=2,
        save_last=False,
        save_weights_only=True,
        filename="{epoch}-{val_loss}-{val_f1_score}-{val_multiclass_accuracy}",
    )

    # Stop training once validation F1 has not improved for 5 checks.
    early_stop_cb = EarlyStopping(monitor=tracked_metric, mode="max", patience=5)

    return [checkpoint_cb, early_stop_cb]

# Open-set label encodings.
# CLASS_DICT  -> one-hot targets over the four species classes.
# class_dict  -> the same four species plus an extra "mosquito" class
#                (presumably the catch-all / unknown category - confirm
#                against src.experiments).
_SPECIES_LABELS = ("albopictus", "culex", "japonicus/koreicus", "culiseta")
_SPECIES_AND_EXTRA_LABELS = _SPECIES_LABELS + ("mosquito",)


def _one_hot_table(labels):
    """Map each label to a float one-hot tensor; the hot slot is the label's position."""
    width = len(labels)
    return {
        label: th.tensor(
            [1 if slot == position else 0 for slot in range(width)],
            dtype=th.float,
        )
        for position, label in enumerate(labels)
    }


CLASS_DICT = _one_hot_table(_SPECIES_LABELS)

class_dict = _one_hot_table(_SPECIES_AND_EXTRA_LABELS)

# --- Experiment configuration (forwarded to get_dataloaders) ---
dataset = 'datacomp_xl_s13b_b90k'
aug = 'hca'
bs = 16
img_size = (224, 224)
shift_box = False

# Change if the working directory is not the 'experiments' folder
img_dir = ""

# Fixed seed for the test-set shuffle so the row order is the same on every
# Restart & Run All (an unseeded sample() reorders the rows on each run).
shuffle_seed = 42

# New annotation for new mos alert partition
val_annotations_csv = "../data_round_2/mosAlert_new_annotation_2/val_annotation_2.csv"
train_annotations_csv = "../data_round_2/mosAlert_new_annotation_2/train_annotation_2.csv"
test_annotations_csv = "../data_round_2/mosAlert_new_annotation_2/test_annotation_2.csv"

train_df = pd.read_csv(train_annotations_csv)
val_df = pd.read_csv(val_annotations_csv)
test_df = pd.read_csv(test_annotations_csv)
# Shuffle the test rows once (reproducibly) and renumber the index from 0.
test_df = test_df.sample(frac=1, random_state=shuffle_seed).reset_index(drop=True)


# Build the experiment once and take both loaders from a single
# get_dataloaders() call. Previously the train loader came from a second
# experiment rooted at a hard-coded "." while only the validation loader
# honoured `img_dir`, so editing `img_dir` silently left training pointed at
# the old location; the call was also executed twice with two of the three
# returned loaders thrown away each time.
# NOTE(review): assumes "." and the default img_dir "" resolve to the same
# directory inside ExperimentMosquitoClassifier - confirm.
experiment = ExperimentMosquitoClassifier(
    img_dir,
    "",
    class_dict=CLASS_DICT,
    class_dict_test=class_dict,
)

# The third returned loader (test) is not used in this notebook section.
train_dataloader, val_dataloader, _ = experiment.get_dataloaders(
    train_df,
    val_df,
    test_df,
    dataset,
    aug,
    bs,
    img_size,
    shift_box,
)

In [None]:
# Sanity check: pull a single batch from the training loader and show the
# shape of its first element and the contents of its second element.
batch = next(iter(train_dataloader), None)
if batch is not None:
    print(batch[0].shape)
    print(batch[1])

In [None]:
# Seed the Python, NumPy and PyTorch RNGs before the model is created.
# `deterministic=True` on the Trainer only requests deterministic algorithms;
# without a fixed seed, weight initialisation and shuffling still differ
# between runs.
pl.seed_everything(42, workers=True)

# Batch size, augmentation name and class count are taken from the shared
# configuration (`bs`, `aug`, `CLASS_DICT`) instead of being repeated as
# literals, so the model cannot drift out of sync with the dataloaders.
model = lc.MosquitoClassifier(
    bs=bs,
    head_version=7,
    freeze_backbones=False,
    label_smoothing=0.1,
    data_aug=aug,
    epochs=15,
    max_steps=60000,
    use_ema=True,
    n_classes=len(CLASS_DICT),
    loss_func="ce",
)

th.set_float32_matmul_precision("high")
trainer = pl.Trainer(
    logger=wandb_logger,
    accelerator="gpu",
    precision="16-mixed",
    # Epoch budget comes from the model's own hyperparameters.
    max_epochs=model.hparams.epochs,
    deterministic=True,
    callbacks=callbacks(),
)

In [None]:
# Train the model on the training loader, validating on the validation loader.
fit_loaders = dict(
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)
trainer.fit(model, **fit_loaders)

In [None]:
# Close the active Weights & Biases run now that training is done.
wandb.finish()