In [None]:
import os
# Set the CUDA_MODULE_LOADING environment variable to 'LAZY'.
# NOTE(review): this sits ahead of the remaining imports, presumably so the
# setting is already in place before any CUDA-backed library initializes.
# Keep it first in the cell unless that assumption is confirmed to be wrong.
os.environ['CUDA_MODULE_LOADING']='LAZY'

import math
import cv2
from collections import namedtuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from timm.utils import NativeScaler
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler

from rsna_utils import *

In [None]:
# Load the annotation table and derive a per-row string key of the form
# "<patient_id>_<laterality>" (e.g. "106_0"), stored in the "id" column.
train_df = pd.read_csv("./dataset/train_with_box.csv")

# Build the key column-wise rather than with a row-wise `apply(axis=1)`:
#   * it is vectorized, so it avoids one Python call per row;
#   * each column is stringified with its own dtype. A row-wise apply
#     materializes every row as a single-dtype Series, so on an all-numeric
#     frame integer ids would be upcast to float and render as "106.0_0.0".
train_df["id"] = (
    train_df["patient_id"].astype(str) + "_" + train_df["laterality"].astype(str)
)

In [None]:
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
device = "cuda"
lr = 5e-5
weight_decay = 0.05
momentum = 0.9
batch_size = 32
EPOCHES = 10

# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------
# Q1 network: built, moved to the target device, then initialised from the
# pretrained checkpoint on disk (in that order).
q1_model = Q1Net(
    stem_chs=[64, 32, 64],
    depths=[3, 4, 20, 3],
    path_dropout=0.2,
)
q1_model.to(device)
load_q1_pretrained("./dataset/nextvit_base_in1k6m_384.pth", q1_model)

# Q2 network: constructed with its defaults; no checkpoint is loaded here.
q2_model = Q2Net()
q2_model.to(device)

# ---------------------------------------------------------------------------
# Optimizer / scheduler configuration
# ---------------------------------------------------------------------------
# timm's factory helpers read their settings as attributes of an `args`
# object, so an immutable namedtuple stands in for an argparse namespace.
Args = namedtuple(
    'Args',
    "opt lr sched epochs warmup_epochs weight_decay decay_rate "
    "decay_epochs cooldown_epochs min_lr warmup_lr batch_size momentum",
)

# NOTE(review): decay_rate / decay_epochs are presumably only consulted by
# step-style schedules and may be unused with sched="cosine"; confirm
# against the installed timm version.
optim_settings = dict(
    opt="adamw",
    sched="cosine",
    lr=lr,
    min_lr=1e-5,
    warmup_lr=1e-6,
    weight_decay=weight_decay,
    momentum=momentum,
    batch_size=batch_size,
    epochs=EPOCHES,
    warmup_epochs=5,
    cooldown_epochs=10,
    decay_rate=0.1,
    decay_epochs=30,
)
args = Args(**optim_settings)

loss_scaler = NativeScaler()

# Both models get their own optimizer and scheduler, built from the same
# settings. `create_scheduler` returns a pair; only the scheduler is kept.
q1_optimizer = create_optimizer(args, q1_model)
q1_lr_scheduler, _ = create_scheduler(args, q1_optimizer)

q2_optimizer = create_optimizer(args, q2_model)
q2_lr_scheduler, _ = create_scheduler(args, q2_optimizer)

# `nn` is not imported explicitly in this notebook; it is expected to come
# into scope through `from rsna_utils import *`.
criterion = nn.CrossEntropyLoss()

In [None]:
# Settings shared by the cells below (passed to run_iteration, load_image,
# patch_generator and z_filling).

# Patch size handed to the image/patch helpers. 384 matches the number in
# the pretrained checkpoint's filename (nextvit_base_in1k6m_384.pth),
# presumably the model input resolution; TODO confirm.
patch_size = 384

# Forwarded to run_iteration right after patch_size.
# NOTE(review): its meaning is not visible from this notebook; see rsna_utils.
k = 2

# Forwarded to run_iteration as `inner_iterations` / `grad_acc_steps`.
# grad_acc_steps is presumably the gradient-accumulation step count; TODO confirm.
inner_iterations = 12
grad_acc_steps = 1

In [None]:
# Run a single `run_iteration` call on a small, hand-picked batch.
# The ids use the "<patient_id>_<laterality>" format of train_df["id"].
batch_patient_ids = ["106_0", "236_0", "283_1", "500_0"]

# One label per id, all set to 1 -> [1, 1, 1, 1].
labels = [1] * len(batch_patient_ids)

# Arguments are positional; keep this order in sync with run_iteration's
# signature in rsna_utils.
run_iteration(
    train_df, batch_patient_ids, labels,
    patch_size, k,
    q1_model, q2_model, criterion,
    q1_optimizer, q2_optimizer,
    inner_iterations, grad_acc_steps, device,
)

In [None]:
# Visual sanity check: load one image and display only the first patch
# that patch_generator yields for it.
image = load_image(train_df, 582, patch_size)

patches = iter(patch_generator(image, patch_size))
try:
    patch = next(patches)
except StopIteration:
    # The generator produced nothing, so there is nothing to draw.
    pass
else:
    # patch[0] selects the first entry along the leading axis.
    # NOTE(review): presumably a channel-first layout; TODO confirm.
    plt.imshow(patch[0])
    plt.show()

In [None]:
# Call z_filling on a small batch using the Q1 model.
# The ids are redefined here so the cell does not rely on an earlier cell's
# value of batch_patient_ids.
batch_patient_ids = [
    "106_0",
    "236_0",
    "283_1",
    "500_0",
]

# z_filling returns a pair.
# NOTE(review): judging by the names, this is presumably a per-batch
# feature matrix plus a padding mask; confirm shapes in rsna_utils.
z_matrix, key_padding_mask = z_filling(
    train_df,
    batch_patient_ids,
    q1_model,
    patch_size,
    device,
)