In [1]:
# Reload edited project modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures with the interactive widget backend.
%matplotlib widget

In [2]:
import numpy as np
import random
import sys
sys.path.append('tools')

import os
import argparse
from dataset import get_dataset, get_handler, get_wa_handler
from torchvision import transforms
import torch
import csv
import time

import query_strategies
import models
from utils import print_log
import kaggle_data_utility
import dataset
import pandas as pd
from tqdm import tqdm
import base_model
from sklearn.model_selection import train_test_split
import pytorch_lightning as pl
import pytorch_lightning.loggers as pl_loggers
import pytorch_lightning.callbacks as pl_callbacks
import data_utility, annotation_utility

In [3]:
import torch
import torchvision
from torch import nn

from lightly.data import LightlyDataset, SwaVCollateFunction
from lightly.loss import SwaVLoss
from lightly.loss.memory_bank import MemoryBankModule
from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes


class SwaV(nn.Module):
    """SwaV model: a backbone followed by a projection head and prototypes.

    Two memory-bank queues keep features of the high resolution views so that
    queue prototypes can be returned once ``start_queue_at_epoch`` is reached.
    """

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.projection_head = SwaVProjectionHead(2048, 2048, 128)
        self.prototypes = SwaVPrototypes(128, 2048, 1)

        # Queue prototypes are only returned from this epoch onwards.
        self.start_queue_at_epoch = 35
        self.queues = nn.ModuleList([MemoryBankModule(size=256) for _ in range(2)])

    def forward(self, high_resolution, low_resolution, epoch):
        """Return (high-res prototypes, low-res prototypes, queue prototypes).

        ``high_resolution`` and ``low_resolution`` are iterables of input
        batches; each one is embedded and scored against the prototypes.
        """
        self.prototypes.normalize()

        high_resolution_features = list(map(self._subforward, high_resolution))
        low_resolution_features = list(map(self._subforward, low_resolution))

        high_resolution_prototypes = []
        for view_features in high_resolution_features:
            high_resolution_prototypes.append(self.prototypes(view_features, epoch))

        low_resolution_prototypes = []
        for view_features in low_resolution_features:
            low_resolution_prototypes.append(self.prototypes(view_features, epoch))

        queue_prototypes = self._get_queue_prototypes(high_resolution_features, epoch)

        return high_resolution_prototypes, low_resolution_prototypes, queue_prototypes

    def _subforward(self, input):
        """Embed one batch: backbone -> flatten -> projection head -> L2 normalisation."""
        flat = self.backbone(input).flatten(start_dim=1)
        projected = self.projection_head(flat)
        return nn.functional.normalize(projected, dim=1, p=2)

    @torch.no_grad()
    def _get_queue_prototypes(self, high_resolution_features, epoch):
        """Push the high resolution features into the queues and score the queue contents.

        Returns ``None`` while ``epoch`` is below ``start_queue_at_epoch``; the
        queues are still updated in that case.

        Raises:
            ValueError: if the number of feature batches differs from the number of queues.
        """
        if len(self.queues) != len(high_resolution_features):
            raise ValueError(
                f"The number of queues ({len(self.queues)}) should be equal to the number of high "
                f"resolution inputs ({len(high_resolution_features)}). Set `n_queues` accordingly."
            )

        # Update every queue with its matching view and collect what the queue returns.
        queue_features = []
        for queue, view_features in zip(self.queues, high_resolution_features):
            _, stored = queue(view_features, update=True)
            # The queue hands back (num_ftrs X queue_length); the views are
            # (batch_size X num_ftrs), so swap the two axes to make them compatible.
            queue_features.append(stored.permute(1, 0))

        # Before the start epoch the features are only queued, no prototypes are produced.
        warming_up = self.start_queue_at_epoch > 0 and epoch < self.start_queue_at_epoch
        if warming_up:
            return None

        # Assign prototypes
        return [self.prototypes(stored, epoch) for stored in queue_features]


In [4]:
import torchvision.transforms as T
from torch.utils.data import Dataset
class RNS_Active(Dataset):
    """Dataset that serves windows as 3-channel images for the active-learning model.

    Args:
        data: indexable array of samples; ``data[i]`` must be a 2-D NumPy array.
        label: 1-D NumPy array with one label per sample. It is stored as a
            column vector, so every item carries a label of shape ``(1,)``.
        transform: dict with a boolean ``'transform'`` entry. A truthy value
            enables the training augmentations (random group order plus the
            image augmentations); a falsy value or ``None`` disables them.
        astensor: when True the image pipelines end with ``ToTensor``,
            otherwise they return PIL images.
    """

    # The second axis of every sample is split into this many equal groups.
    N_GROUPS = 4

    def __init__(self, data, label, transform=None, astensor=True):
        self.data = data
        self.label = label
        # The default ``transform=None`` used to crash on subscripting; it now
        # means "no augmentation".
        self.transform = transform['transform'] if transform is not None else False
        print('data loaded')

        # (N,) -> (N, 1)
        self.label = self.label[np.newaxis].T

        self.length = len(self.data)

        print(data.shape)
        print(label.shape)

        def build_pipeline(augment):
            # Both pipelines share the PIL conversion and the resize; only the
            # random augmentations and the final ToTensor are optional.
            steps = [
                T.ToPILImage(),
                T.Resize((256, 256), interpolation=T.InterpolationMode.NEAREST),
            ]
            if augment:
                steps += [
                    T.RandomApply([T.ColorJitter()], p=0.5),
                    T.RandomApply([T.GaussianBlur(kernel_size=(3, 3))], p=0.5),
                    T.RandomInvert(p=0.2),
                    T.RandomPosterize(4, p=0.2),
                ]
            if astensor:
                steps.append(T.ToTensor())
            return T.Compose(steps)

        self.augmentation = build_pipeline(augment=True)
        self.totensor = build_pipeline(augment=False)

    def __len__(self):
        return self.length

    @staticmethod
    def _grouped_indices(width, shuffle, n_groups=4):
        """Return the indices ``0..width-1`` arranged as ``n_groups`` contiguous groups.

        With ``shuffle=True`` the order of the groups is permuted with
        ``np.random.shuffle`` (the order inside each group is kept).

        Raises:
            ValueError: if ``width`` is not a multiple of ``n_groups``.
        """
        if width % n_groups != 0:
            raise ValueError(
                f"sample width ({width}) must be a multiple of {n_groups}"
            )
        # Integer arithmetic only: the previous float based version relied on
        # NumPy silently truncating float repeat counts.
        group_len = width // n_groups
        group_order = np.arange(n_groups)
        if shuffle:
            np.random.shuffle(group_order)
        return (group_order[:, np.newaxis] * group_len + np.arange(group_len)).ravel()

    def __getitem__(self, index):
        """Return ``(image, label, index)`` for one sample."""
        data = self.data[index]
        label = self.label[index]

        augment = bool(self.transform)
        # NOTE(review): the group size is derived from data.shape[1] but the
        # indices select along axis 0, exactly as before — confirm this matches
        # the layout produced by get_concatenated_data(arrange='channel_stack').
        row_index = self._grouped_indices(data.shape[1], shuffle=augment, n_groups=self.N_GROUPS)
        data = data[row_index]
        data = torch.from_numpy(data).clone()
        # Replicate the single plane three times to mimic an RGB image.
        data = data.repeat(3, 1, 1)
        data = self.augmentation(data) if augment else self.totensor(data)

        return data, torch.from_numpy(label).to(dtype=torch.long), index

In [5]:
def sigmoid_focal_loss(
        inputs: torch.Tensor,
        targets: torch.Tensor,
        alpha: float = 0.25,
        gamma: float = 2,
        reduction: str = "none",
) -> torch.Tensor:
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs (Tensor): A float tensor of arbitrary shape.
                The raw (pre-sigmoid) predictions for each example; the sigmoid
                is applied inside this function.
        targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
                classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha (float): Weighting factor in range (0,1) to balance
                positive vs negative examples or -1 for ignore. Default: ``0.25``.
        gamma (float): Exponent of the modulating factor (1 - p_t) to
                balance easy vs hard examples. Default: ``2``.
        reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
                ``'none'``: No reduction will be applied to the output.
                ``'mean'``: The output will be averaged.
                ``'sum'``: The output will be summed. Default: ``'none'``.
    Returns:
        Loss tensor with the reduction option applied.
    Raises:
        ValueError: if ``reduction`` is not one of the supported modes.
    """
    # Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py

    # Reject an unsupported reduction before doing any tensor work.
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(
            f"Invalid Value for arg 'reduction': '{reduction}' \n Supported reduction modes: 'none', 'mean', 'sum'"
        )

    p = torch.sigmoid(inputs)

    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t is the probability assigned to the true class of each element.
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    # A negative alpha disables the class balancing term.
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    return loss
import torch.nn.functional as F
import sklearn

class ActiveLearning(pl.LightningModule):
    """Binary classifier: a backbone followed by a four layer MLP head.

    Args:
        backbone: feature extractor whose output can be viewed as ``(-1, 2048)``.
        unfreeze_backbone_at_epoch: the backbone is kept frozen (eval mode, no
            gradients) while ``current_epoch`` is below this value.
    """

    def __init__(self, backbone, unfreeze_backbone_at_epoch=100):
        super().__init__()
        self.backbone = backbone
        self.fc1 = nn.Linear(2048, 512)
        self.fc2 = nn.Linear(512, 64)
        self.fc3 = nn.Linear(64, 8)
        self.fc4 = nn.Linear(8, 2)
        self.softmax = nn.Softmax(dim=1)
        # Focal loss hyper-parameters.
        self.alpha = 0.5
        self.gamma = 8
        self.unfreeze_backbone_at_epoch = unfreeze_backbone_at_epoch

    def _embed(self, x):
        """Run the backbone and flatten its output to ``(batch, 2048)``."""
        return self.backbone(x).view(-1, 2048)

    def _head(self, emb):
        """Map embeddings to the two class logits."""
        x = F.relu(self.fc1(emb))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.fc4(x)

    def forward(self, x):
        """Return the class logits for a batch of inputs."""
        return self._head(self._embed(x))

    def _loss(self, logits, y):
        """Return ``(focal loss, one-hot labels)`` for a batch.

        The loss applies the sigmoid itself, so it must receive the raw logits;
        feeding it softmax probabilities squashes the predictions twice.
        """
        # view(-1) instead of squeeze() keeps the batch axis for batches of one.
        label = F.one_hot(y.view(-1), num_classes=2)
        loss = sigmoid_focal_loss(logits.float(), label.float(), alpha=self.alpha, gamma=self.gamma,
                                  reduction='mean')
        return loss, label

    def training_step(self, batch, batch_idx):
        x, y = batch
        if self.current_epoch < self.unfreeze_backbone_at_epoch:
            # Frozen phase: the backbone forward pass itself has to run under
            # no_grad, otherwise gradients still reach the backbone weights.
            self.backbone.eval()
            with torch.no_grad():
                emb = self._embed(x)
        else:
            emb = self._embed(x)
        loss, _ = self._loss(self._head(emb), y)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # Imported here so the metrics submodule is guaranteed to be loaded.
        from sklearn.metrics import accuracy_score, precision_recall_fscore_support

        x, y = batch
        logits = self(x)
        loss, label = self._loss(logits, y)
        pred = self.softmax(logits)

        out = torch.argmax(pred, dim=1).detach().cpu().numpy()
        target = y.view(-1).detach().cpu().numpy()
        # scikit-learn expects (y_true, y_pred); with the arguments swapped the
        # logged precision and recall trade places.
        precision, recall, fscore, support = precision_recall_fscore_support(
            target, out, labels=[0, 1], zero_division=0)
        acc = accuracy_score(target, out)
        # Logging to TensorBoard (if installed) by default
        self.log("val_loss", loss)
        self.log("val_acc", acc)
        self.log("val_precision", precision[1])
        self.log("val_recall", recall[1])
        return pred, label

    def predict_step(self, batch, batch_idx):
        """Return ``(logits, labels)``; no softmax is applied here."""
        x, y = batch
        return self(x), y

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-2)
        return optimizer

In [6]:
# Seed every random number generator used below (Python, PyTorch, NumPy).
random_seed = 42
random.seed(random_seed)
torch.manual_seed(random_seed)
np.random.seed(random_seed)

if torch.cuda.is_available():
    torch.cuda.manual_seed(random_seed)
    # True ensures the algorithm selected by cuDNN is deterministic
    torch.backends.cudnn.deterministic = True
    # torch.set_deterministic(True)
    # False ensures CUDA select the same algorithm each time the application is run
    torch.backends.cudnn.benchmark = False

In [7]:
# Build one cache file per annotated patient: windowed, normalised, concatenated
# data together with one label per window.
annotation_csv = 'full_updated_anns_annotTbl_cleaned.csv'
cache_dir = 'rns_test_cache'

raw_annotations = pd.read_csv(annotation_csv)
ids = list(np.unique(raw_annotations[raw_annotations['descriptions'].notnull()]['HUP_ID']))
data_import = data_utility.read_files(path='data/rns_data', path_data='rns_raw_cache', patientIDs=ids,
                                      verbose=True)  # Import data with annotation
annotations = annotation_utility.read_annotation(annotation_path=annotation_csv,
                                                 data=data_import, n_class=3)
annot = annotations.annotations
patient_list = list(np.unique(annot['Patient_ID']))
# patient_list = ['RNS026', 'HUP159', 'HUP129', 'HUP096', 'HUP182']

clip_dict = annotation_utility.combine_annot_index(annot, patient_list, 42)

window_len = 1
stride = 1
concat_n = 4

# np.save does not create missing directories, so make sure the target exists.
os.makedirs(cache_dir, exist_ok=True)

# `patient_id` instead of `id`, so the builtin is not shadowed for the rest of the notebook.
for patient_id in tqdm(clip_dict.keys()):
    patient = data_import[patient_id]
    clip_start, clip_end, clip_labels = clip_dict[patient_id][0], clip_dict[patient_id][1], clip_dict[patient_id][2]

    patient.set_window_parameter(window_length=window_len, window_displacement=stride)
    patient.set_concatenation_parameter(concatenate_window_n=concat_n)
    window_indices, _ = patient.get_windowed_data(clip_start, clip_end)

    # Every window inherits the label of the clip it was cut from. The empty
    # float array keeps the result float64 (and valid when there are no windows).
    import_label = np.hstack(
        [np.array([])] + [np.repeat(clip_labels[i], len(ind)) for i, ind in enumerate(window_indices)]
    )

    patient.normalize_windowed_data()
    _, concatenated_data = patient.get_concatenated_data(patient.windowed_data, arrange='channel_stack')
    assert import_label.shape[0] == concatenated_data.shape[0]
    np.save(os.path.join(cache_dir, patient_id + '.npy'), {'data': concatenated_data, 'label': import_label})

100%|██████████| 18/18 [00:10<00:00,  1.64it/s]


(2, 25)
(2, 20)
(2, 59)
(2, 39)
(2, 4)
(2, 60)
(2, 21)
(2, 41)
(2, 44)
(2, 0)
(2, 71)
(2, 99)
(2, 24)
(2, 0)
(2, 40)
(2, 0)
(2, 73)
(2, 51)


100%|██████████| 15/15 [00:12<00:00,  1.18it/s]


In [8]:
def get_data(file_names, split=0.7, cache_dir='rns_test_cache'):
    """Load cached files and split each one chronologically into train and test parts.

    Every file is expected to hold a pickled dict with a ``'data'`` array and a
    ``'label'`` array. The first ``int(len(data) * split)`` rows of each file go
    to the training set and the remaining rows to the test set.

    Args:
        file_names: non-empty sequence of file names inside ``cache_dir``.
        split: fraction of every file assigned to the training set.
        cache_dir: directory that contains the cache files.

    Returns:
        ``(train_data, train_label, test_data, test_label)`` as float64 arrays.

    Raises:
        ValueError: if ``file_names`` is empty.
    """
    if len(file_names) == 0:
        raise ValueError("file_names is empty: there is nothing to load")

    train_parts, train_label_parts = [], []
    test_parts, test_label_parts = [], []

    for name in tqdm(file_names):
        # NOTE: allow_pickle executes pickled code; only load cache files that
        # were written by this notebook.
        cache = np.load(os.path.join(cache_dir, name), allow_pickle=True).item()
        data = cache.get('data')
        label = cache.get('label')
        split_n = int(data.shape[0] * (split))
        train_parts.append(data[:split_n])
        train_label_parts.append(label[:split_n])
        test_parts.append(data[split_n:])
        test_label_parts.append(label[split_n:])

    # Concatenate once instead of re-copying the growing array for every file.
    # The empty float64 seeds keep the result dtype identical to the previous
    # incremental vstack/hstack implementation.
    data_seed = np.empty((0,) + tuple(train_parts[0].shape[1:]))
    label_seed = np.array([])
    train_data = np.concatenate([data_seed] + train_parts)
    train_label = np.concatenate([label_seed] + train_label_parts)
    test_data = np.concatenate([data_seed] + test_parts)
    test_label = np.concatenate([label_seed] + test_label_parts)

    return train_data, train_label, test_data, test_label

In [9]:
# os.listdir returns entries in arbitrary order; sorting keeps the sample order
# (and therefore every pool index used below) reproducible across machines.
# Anything that is not a cache file is ignored.
data_list = sorted(name for name in os.listdir('rns_test_cache') if name.endswith('.npy'))

X_train, y_train, X_test, y_test = get_data(data_list, split=0.8)

print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)

100%|██████████| 15/15 [00:17<00:00,  1.14s/it]

(84503, 249, 36)
(84503,)
(21132, 249, 36)
(21132,)





In [10]:
# Sum of the held-out labels; presumably the number of positive windows if labels are 0/1 — TODO confirm.
y_test.sum()

5320.0

In [11]:
# Active-learning schedule, expressed in percent of the training pool
# (each value is scaled by n_pool / 100 where the round sizes are computed).
nStart = 1   # initially labelled share
nEnd = 20    # labelled share at which querying stops
nQuery = 2   # share queried per round

n_pool = len(y_train)
n_test = len(y_test)

# Prefix used for the checkpoint directory and the CSV log directory.
save_file_name = 'rns_active_lc'

In [12]:
# Translate the percentage schedule into absolute sample counts.
NUM_INIT_LB = int(nStart * n_pool / 100)
NUM_QUERY = int(nQuery * n_pool / 100) if nStart != 100 else 0

if NUM_QUERY > 0:
    # Number of rounds needed to go from the initial set to the final budget,
    # rounded up so that a partial last round is still performed.
    remaining = int(nEnd * n_pool / 100) - NUM_INIT_LB
    NUM_ROUND = int(remaining / NUM_QUERY)
    if remaining % NUM_QUERY != 0:
        NUM_ROUND += 1
else:
    # Nothing can be queried (everything is labelled, or the pool is so small
    # that the per-round share rounds down to zero); this used to divide by zero.
    NUM_ROUND = 0

print(NUM_INIT_LB)
print(NUM_QUERY)
print(NUM_ROUND)

845
1690
10


In [13]:
# Boolean mask over the training pool: True marks the samples that are currently labelled.
idxs_lb = np.zeros(n_pool, dtype=bool)
# Draw the initial NUM_INIT_LB labelled samples uniformly at random.
idxs_tmp = np.arange(n_pool)
np.random.shuffle(idxs_tmp)
idxs_lb[idxs_tmp[:NUM_INIT_LB]] = True

In [14]:
# Pick the device once and use it for both the model and the Lightning trainer.
device = "cuda" if torch.cuda.is_available() else "cpu"
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'

# map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine.
ckpt = torch.load("rns_ckpt/checkpoint31.pth", map_location='cpu')
resnet = torchvision.models.resnet50()
# Drop the final fully connected layer so the backbone outputs pooled features.
backbone = nn.Sequential(*list(resnet.children())[:-1])
swav = SwaV(backbone)
swav.load_state_dict(ckpt['model_state_dict'])
# Only the pretrained backbone is reused by the classifier.
model = ActiveLearning(swav.backbone)

model.to(device)

# Pool indices of the currently labelled samples.
idxs_train = np.arange(n_pool)[idxs_lb]

checkpoint_callback = pl_callbacks.ModelCheckpoint(monitor='val_loss', filename=save_file_name+'_round_0-{epoch:02d}-{val_loss:.5f}', dirpath=save_file_name + '_ckpt')
csv_logger = pl_loggers.CSVLogger(save_file_name + '_log', name="logger_round_0")
trainer = pl.Trainer(logger=csv_logger, max_epochs=30, callbacks=[checkpoint_callback], accelerator=accelerator, devices=1, log_every_n_steps=5)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [15]:
from copy import deepcopy
# Snapshot of the weights taken before any fine-tuning; later rounds reload it
# so that every round starts from the same initial state.
modelstate = deepcopy(model.state_dict())

In [16]:
def collate_fn(batch):
    """Stack the first two fields (data, label) of every sample into batch tensors.

    Any further fields of the samples are ignored.
    """
    fields = tuple(zip(*batch))
    samples, targets = fields[0], fields[1]
    return torch.stack(samples), torch.stack(targets)
# Dataset options: augment the training samples, leave the evaluation samples untouched.
transforms_param = {
    'transform_tr': {'transform': True},
    'transform_te': {'transform': False},
}

train_data = RNS_Active(X_train[idxs_train], y_train[idxs_train], transform=transforms_param['transform_tr'])
test_data = RNS_Active(X_test, y_test, transform=transforms_param['transform_te'])

# Options shared by both loaders; only the shuffling differs.
_loader_options = dict(batch_size=128, collate_fn=collate_fn, drop_last=True)
train_dataloader = torch.utils.data.DataLoader(train_data, shuffle=True, **_loader_options)
val_dataloader = torch.utils.data.DataLoader(test_data, shuffle=False, **_loader_options)

# Round 0: train on the initial labelled set.
trainer.fit(model, train_dataloader, val_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type       | Params
----------------------------------------
0 | backbone | Sequential | 23.5 M
1 | fc1      | Linear     | 1.0 M 
2 | fc2      | Linear     | 32.8 K
3 | fc3      | Linear     | 520   
4 | fc4      | Linear     | 18    
5 | softmax  | Softmax    | 0     
----------------------------------------
24.6 M    Trainable params
0         Non-trainable params
24.6 M    Total params
98.362    Total estimated model params size (MB)


data loaded
(845, 249, 36)
(845,)
data loaded
(21132, 249, 36)
(21132,)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


In [17]:
def random_query(idxs_lb, n):
    """Return up to ``n`` randomly chosen indices whose entry in ``idxs_lb`` is False/0."""
    unlabeled = np.nonzero(idxs_lb == 0)[0]
    order = np.random.permutation(unlabeled.shape[0])
    shuffled = unlabeled[order]
    return shuffled[:n]

def _predict_unlabeled(X_train, y_train, trainer, model, idxs_lb):
    """Predict class logits for every unlabelled pool sample.

    Returns ``(idxs_unlabeled, logits)`` where row ``i`` of ``logits`` belongs
    to pool index ``idxs_unlabeled[i]``. ``logits`` is ``None`` when the pool is
    empty.
    """
    idxs_unlabeled = np.arange(len(idxs_lb))[~idxs_lb]
    if len(idxs_unlabeled) == 0:
        return idxs_unlabeled, None

    untrained_data = RNS_Active(X_train[idxs_unlabeled], y_train[idxs_unlabeled],
                                transform=transforms_param['transform_te'])
    # The loader must neither shuffle nor drop samples: the position of a
    # prediction is the only link back to its pool index.
    untrained_dataloader = torch.utils.data.DataLoader(untrained_data,
                                                       batch_size=128,
                                                       shuffle=False,
                                                       collate_fn=collate_fn,
                                                       drop_last=False, )
    predictions = trainer.predict(model, untrained_dataloader)
    logits = torch.vstack([pred for pred, _ in predictions]).detach().float().cpu()
    return idxs_unlabeled, logits


def _lowest_scoring(idxs_unlabeled, scores, n):
    """Return the pool indices of the ``n`` smallest ``scores``."""
    if n <= 0:
        return idxs_unlabeled[:0]
    positions = scores.sort()[1][:n].numpy()
    return idxs_unlabeled[positions]


def entropy_query(X_train, y_train, trainer, model, idxs_lb, n):
    """Select the ``n`` unlabelled samples with the highest predictive entropy.

    Args:
        X_train, y_train: the full training pool.
        trainer: object with a ``predict(model, dataloader)`` method that yields
            ``(logits, labels)`` per batch.
        model: model passed on to ``trainer.predict``.
        idxs_lb: boolean mask over the pool, True for labelled samples.
        n: number of samples to select.

    Returns:
        NumPy array of pool indices, most uncertain first.
    """
    idxs_unlabeled, logits = _predict_unlabeled(X_train, y_train, trainer, model, idxs_lb)
    if logits is None:
        return idxs_unlabeled
    # log_softmax avoids log(0) = -inf, which would turn 0 * log(0) into NaN.
    log_probs = torch.log_softmax(logits, dim=1)
    # Negative entropy: the smallest values are the most uncertain samples.
    U = (log_probs.exp() * log_probs).sum(1)
    return _lowest_scoring(idxs_unlabeled, U, n)


def lease_conf_query(X_train, y_train, trainer, model, idxs_lb, n):
    """Select the ``n`` unlabelled samples whose top class probability is lowest.

    (Least-confidence sampling; the historical spelling of the name is kept for
    existing callers.) Arguments and return value are the same as for
    ``entropy_query``.
    """
    idxs_unlabeled, logits = _predict_unlabeled(X_train, y_train, trainer, model, idxs_lb)
    if logits is None:
        return idxs_unlabeled
    probs = torch.softmax(logits, dim=1)
    # Confidence of the predicted class; the smallest values are the least confident.
    U = probs.max(1)[0]
    return _lowest_scoring(idxs_unlabeled, U, n)

In [1]:
# Total number of labelled samples allowed at the end of the schedule.
label_budget = int(nEnd * n_pool / 100)
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'

# The validation set never changes, so build it once instead of once per round.
test_data = RNS_Active(X_test, y_test, transform=transforms_param['transform_te'])
val_dataloader = torch.utils.data.DataLoader(
    test_data,
    batch_size=128,
    collate_fn=collate_fn,
    shuffle=False,
    drop_last=True,
)

for rd in range(1, NUM_ROUND + 1):
    print('Round {}/{}'.format(rd, NUM_ROUND), flush=True)

    # Cap the query at the remaining budget. A local variable is used so the
    # global NUM_QUERY is not overwritten and the cell stays re-runnable.
    n_query = min(NUM_QUERY, label_budget - int(idxs_lb.sum()))
    if n_query <= 0:
        # Budget exhausted: nothing left to query or to retrain on.
        break

    # Query with the trainer/model of the previous round.
    # Alternatives: entropy_query(X_train, y_train, trainer, model, idxs_lb, n_query)
    #               random_query(idxs_lb, n_query)
    q_idxs = lease_conf_query(X_train, y_train, trainer, model, idxs_lb, n_query)

    # Mark the queried samples as labelled (on a copy, leaving the old mask intact).
    idxs_lb = idxs_lb.copy()
    idxs_lb[q_idxs] = True
    print(int(idxs_lb.sum()))

    idxs_train = np.arange(n_pool)[idxs_lb]
    train_data = RNS_Active(X_train[idxs_train], y_train[idxs_train], transform=transforms_param['transform_tr'])
    train_dataloader = torch.utils.data.DataLoader(train_data,
                                                   batch_size=128,
                                                   shuffle=True,
                                                   collate_fn=collate_fn,
                                                   drop_last=True, )

    # Every round retrains from the same initial weights on the enlarged labelled set.
    model.load_state_dict(modelstate)
    checkpoint_callback = pl_callbacks.ModelCheckpoint(monitor='val_loss', filename=save_file_name+'_round_' + str(
        rd) + '-{epoch:02d}-{val_loss:.5f}', dirpath=save_file_name + '_ckpt')
    csv_logger = pl_loggers.CSVLogger(save_file_name + '_log', name="logger_round_" + str(rd))
    trainer = pl.Trainer(logger=csv_logger, max_epochs=30, callbacks=[checkpoint_callback], accelerator=accelerator, devices=1, log_every_n_steps=5)
    trainer.fit(model, train_dataloader, val_dataloader)

NameError: name 'NUM_ROUND' is not defined