# ResNet-18

In [21]:
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import numpy as np
import cv2
import torchvision.transforms as T
from tqdm.notebook import tqdm
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torchvision.models import resnet18
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim 

### Set the path to the data directory (`data_path`)

In [11]:
import random 
def split_data(data_dir, train_size=0.8, val_size=0.1, seed=1234):
    """Shuffle the files under ``data_dir`` and split them into train/val/test lists.

    Files are collected as ``data_dir/<class>/<file>`` (``.zip`` archives are
    skipped), shuffled with a private ``random.Random(seed)`` and cut at
    ``int(n * train_size)`` and ``int(n * val_size)``; the test split receives
    the remainder.

    Args:
        data_dir: dataset root directory.
        train_size: fraction of files assigned to the training split.
        val_size: fraction of files assigned to the validation split.
        seed: shuffle seed; the default 1234 is the seed previously hard-coded here.

    Returns:
        ``(train_data, val_data, test_data)`` lists of ``Path`` objects.
    """
    files = [p for p in Path(data_dir).glob('*/*') if p.is_file() and p.suffix != '.zip']
    # NOTE(review): Path.glob order is filesystem-dependent, so the split is only
    # reproducible across machines if that order is stable.
    # A private generator produces the same permutation as random.seed(seed)
    # followed by random.shuffle, without resetting the global RNG.
    random.Random(seed).shuffle(files)
    n_train = int(len(files) * train_size)
    n_val = int(len(files) * val_size)
    return files[:n_train], files[n_train:n_train + n_val], files[n_train + n_val:]

# Only the held-out test split is needed for evaluation in this notebook.
_, __, test_data = split_data(data_dir='data')


In [9]:
# TODO: set data_path to the directory that holds the <class>/<image> sub-folders.
data_path = ''
model_path = Path('models/')
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
batch_size = 32

In [10]:

def get_data(data_dir):
    """Return every file exactly two levels below ``data_dir`` (``<class>/<file>``).

    ``.zip`` archives and directories are skipped. The order is whatever
    ``Path.glob`` yields; it is not sorted.
    """
    return [p for p in Path(data_dir).glob('*/*') if p.is_file() and p.suffix != '.zip']

# NOTE(review): this rebinds test_data, replacing the split produced by
# split_data('data') when cells run top to bottom — confirm which set is intended.
test_data = get_data(data_dir=data_path)

In [12]:
class ThreatDataset(Dataset):
    """Image-classification dataset over files laid out as ``<class_name>/<image>``.

    The label is the index of the file's parent-folder name in ``folder_names``.
    Images are read with OpenCV (BGR order) and converted to ``color_space``
    before ``transforms`` is applied.

    Args:
        data: sequence of ``pathlib.Path`` image paths.
        loader_type: unused; kept for backward compatibility.
        transforms: optional callable applied to the converted image array.
        color_space: 'rgb', 'gray' (replicated to 3 channels), 'hsv' or 'lab';
            any other value leaves the image in OpenCV's BGR order.
    """

    def __init__(self, data, loader_type='train', transforms=None, color_space='rgb'):
        self.folder_names = ['carrying', 'threat', 'normal']
        self.data = data
        self.color_space = color_space
        self.transform = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for item ``idx``.

        Raises:
            ValueError: if the parent folder is not one of ``folder_names``.
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        path = self.data[idx]
        label = self.folder_names.index(path.parent.name)
        image = cv2.imread(str(path))
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'Could not read image: {path}')
        image = self._convert_color(image)
        if self.transform:
            image = self.transform(image)
        return image, label

    def _convert_color(self, image):
        """Convert a BGR image to ``self.color_space``; unknown spaces pass through."""
        if self.color_space == 'rgb':
            return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.color_space == 'gray':
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            # Replicate to 3 channels so 3-channel transforms/models still apply.
            return np.dstack((gray, gray, gray))
        if self.color_space == 'hsv':
            return cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        if self.color_space == 'lab':
            return cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
        return image


# Resize to 224x224 and apply ImageNet mean/std normalization.
test_transforms = T.Compose([
    T.ToPILImage(),
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

transforms = dict(test=test_transforms)

# One dataset per colour space, all over the same test files.
rgb_test_dataset, hsv_test_dataset, lab_test_dataset, gray_test_dataset = (
    ThreatDataset(test_data, transforms=transforms['test'], color_space=space)
    for space in ('rgb', 'hsv', 'lab', 'gray')
)

# Evaluation loaders keep a fixed order.
rgb_test_loader, hsv_test_loader, lab_test_loader, gray_test_loader = (
    DataLoader(ds, batch_size=batch_size, shuffle=False)
    for ds in (rgb_test_dataset, hsv_test_dataset, lab_test_dataset, gray_test_dataset)
)

In [13]:
# THE TRAIN AND VALIDATE FUNCTIONS HAVE BEEN REMOVED, YOU CAN FIND THEM IN THE SCRIPTS DIRECTORY
class Trainer:
    """Evaluation helper for a 3-class image classifier.

    Training/validation methods are not part of this notebook; this class only
    evaluates a model (``test``) and saves/loads its ``state_dict``.
    """

    def __init__(self, model, optimizer, criterion, scheduler, device):
        """Store the evaluation components and resolve the target device.

        Args:
            model: network returning one score per class (3 classes).
            optimizer: stored only; ``test`` does not use it.
            criterion: loss called as ``criterion(logits, one_hot_targets)`` with
                float one-hot targets (``nn.CrossEntropyLoss`` accepts
                class-probability targets, unlike ``nn.NLLLoss``).
            scheduler: stored only; ``test`` does not use it.
            device: ``str`` (e.g. ``'cuda'``, ``'cpu'``) or ``torch.device``.
                CUDA is used only if requested and available; otherwise CPU.
        """
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.scheduler = scheduler
        self.best_acc = 1 / 3  # presumably chance level for 3 classes
        self.train_acc_arr = []
        self.val_acc_arr = []
        self.train_losses = []
        self.val_losses = []
        self.test_acc = 0
        self.test_loss = 0
        # Normalize to torch.device so both str and torch.device arguments work.
        requested = device if isinstance(device, torch.device) else torch.device(device)
        if requested.type == 'cuda' and torch.cuda.is_available():
            self.device = requested
        else:
            self.device = torch.device('cpu')
        # test() moves inputs to self.device, so the parameters must live there too.
        self.model.to(self.device)

    def test(self, test_loader, name='model_final'):
        """Evaluate on ``test_loader`` and record accuracy and mean loss.

        Args:
            test_loader: iterable of ``(images, int_labels)`` batches.
            name: unused; kept for backward compatibility.

        Sets ``self.test_acc`` (fraction of correct samples) and
        ``self.test_loss`` (mean of the per-batch losses, as a float), and
        prints the accuracy as a percentage.

        Raises:
            ValueError: if the loader yields no samples.
        """
        self.model.eval()
        total = 0
        correct = 0
        batch_losses = []
        with torch.no_grad():
            # Wrap the loader itself so tqdm can show a total.
            for i, (x, y) in enumerate(tqdm(test_loader)):
                x = x.to(self.device)
                # The criterion takes one-hot float targets; accuracy uses the int labels.
                y_one_hot = F.one_hot(y, num_classes=3).to(self.device).float()
                y_pred = self.model(x)
                loss = self.criterion(y_pred, y_one_hot)
                batch_losses.append(loss.item())

                _, predicted = torch.max(y_pred, 1)
                correct += (predicted.cpu() == y).sum().item()
                total += y.size(0)
                if i % 100 == 0:
                    print(f'Test Loss: {loss.item()}')
        if total == 0:
            raise ValueError('test_loader yielded no samples')
        print(f'Accuracy: {100 * correct / total}')
        self.test_acc = correct / total
        self.test_loss = sum(batch_losses) / len(batch_losses)

    def save_model(self, path):
        """Save the model ``state_dict`` to ``models/<path>.pth``."""
        torch.save(self.model.state_dict(), f'models/{path}.pth')

    def load_model(self, path):
        """Load a ``state_dict`` from ``path``, mapping tensors onto ``self.device``."""
        self.model.load_state_dict(torch.load(path, map_location=self.device))


In [14]:
# Grayscale-input ResNet-18 with a 3-way head; weights are restored from disk.
model = resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, 3)

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

model.load_state_dict(
    torch.load(model_path / 'resnet18_gray_29.pth', map_location=torch.device(device))
)

trainer = Trainer(model, optimizer, criterion, scheduler, device)
trainer.test(gray_test_loader)

1it [00:04,  4.37s/it]

Test Loss: 0.95920729637146


16it [00:55,  3.50s/it]

Accuracy: 64.64646464646465





In [15]:
# RGB-input ResNet-18 with a 3-way head; weights are restored from disk.
model = resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, 3)

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

model.load_state_dict(
    torch.load(model_path / 'resnet18_rgb_25.pth', map_location=torch.device(device))
)

trainer = Trainer(model, optimizer, criterion, scheduler, device)
trainer.test(rgb_test_loader)

1it [00:03,  3.79s/it]

Test Loss: 0.961929440498352


16it [00:51,  3.22s/it]

Accuracy: 66.86868686868686





In [16]:
# HSV-input ResNet-18 with a 3-way head; weights are restored from disk.
model = resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, 3)

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

model.load_state_dict(
    torch.load(model_path / 'resnet18_hsv_46.pth', map_location=torch.device(device))
)

trainer = Trainer(model, optimizer, criterion, scheduler, device)
trainer.test(hsv_test_loader)

1it [00:03,  3.86s/it]

Test Loss: 0.8658714294433594


16it [00:50,  3.15s/it]

Accuracy: 63.03030303030303





In [17]:
# Lab-input ResNet-18 with a 3-way head; weights are restored from disk.
model = resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, 3)

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

model.load_state_dict(
    torch.load(model_path / 'resnet18_lab_37.pth', map_location=torch.device(device))
)

trainer = Trainer(model, optimizer, criterion, scheduler, device)
trainer.test(lab_test_loader)

1it [00:04,  4.02s/it]

Test Loss: 0.7498617768287659


16it [00:50,  3.13s/it]

Accuracy: 64.24242424242425





# Bi-FPN

This model is very memory-intensive, so small batch sizes must be used. Avoid running it on a CPU.
Approximate memory required: batch_size × 1 GB (about 1 GB per sample).

In [23]:
from efficientdet.model import BiFPN
from efficientnet_pytorch.efficientnet import EfficientNet_Head
from efficientdet.utils import Anchors

In [44]:
# Bi-FPN needs roughly 1 GB per sample, so keep batches small.
batch_size, num_workers = 8, 4

In [19]:

# FFTDataset is defined once, immediately below; a byte-identical copy that was
# shadowed by that definition has been dropped here as dead code.

class FFTDataset(Dataset):
    """Dataset yielding ``(image, label, fourier_map)`` triples.

    ``data`` is a sequence of ``pathlib.Path`` image paths whose parent folder
    name is the class ('carrying', 'threat' or 'normal'). Each item contains:
      * the image resized to 512x512, left in OpenCV's BGR order, and passed
        through ``transforms`` when given;
      * the label index into ``folder_names``;
      * a ``(1, 80, 80)`` float tensor holding the normalized log-magnitude
        spectrum of the grayscale image (the Fourier target).
    """

    def __init__(self, data, transforms=None):
        self.folder_names = ['carrying', 'threat', 'normal']
        self.data = data
        self.transforms = transforms

    def __getitem__(self, idx):
        """Return ``(image, label, ft_sample)`` for item ``idx``.

        Raises:
            ValueError: if the parent folder is not one of ``folder_names``.
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        path = self.data[idx]
        label = self.folder_names.index(path.parent.name)
        img = cv2.imread(str(path))
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'Could not read image: {path}')
        ft_sample = cv2.resize(self.generate_FT(img), (80, 80))
        ft_sample = torch.from_numpy(ft_sample).float().unsqueeze(0)
        img = cv2.resize(img, (512, 512))
        if self.transforms:
            img = self.transforms(img)
        return img, label, ft_sample

    def generate_FT(self, image):
        """Return the normalized log-magnitude spectrum of a BGR image."""
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        return self._fourier_features(gray)

    @staticmethod
    def _fourier_features(gray):
        """Centered log-magnitude FFT of a 2-D array, rescaled into (0, 1]."""
        spectrum = np.log(np.abs(np.fft.fftshift(np.fft.fft2(gray))) + 1)
        # Vectorized min/max: same values as a per-row Python scan, far faster.
        lo = spectrum.min()
        hi = spectrum.max()
        # The +1 keeps the denominator positive for constant images and maps the
        # range to [1/(hi-lo+1), 1]; presumably kept as-is to match trained weights.
        return (spectrum - lo + 1) / (hi - lo + 1)

    def __len__(self):
        return len(self.data)

In [36]:

# FFTDataset already resizes images to 512x512, so no Resize step is needed here.
test_transforms = T.Compose([
    T.ToPILImage(),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

transforms = {
    'test': test_transforms
}
test_dataset = FFTDataset(test_data, transforms=transforms['test'])
# Evaluation does not need shuffling; a fixed order keeps runs reproducible.
# Pinned memory only helps host-to-GPU copies, so enable it only with CUDA.
test_loader = DataLoader(test_dataset, batch_size=batch_size,
                         num_workers=num_workers,
                         pin_memory=torch.cuda.is_available(),
                         shuffle=False)

In [37]:
# Self-supervised learning
class FTGen_1(nn.Module):
    """Predict a Fourier-spectrum map from a feature map.

    Three 3x3 Conv -> BatchNorm -> ReLU stages
    (``in_channels -> 128 -> 64 -> out_channels``); ``padding=1`` keeps the
    spatial size unchanged.
    """

    def __init__(self, in_channels=64, out_channels=1):
        super().__init__()
        widths = (in_channels, 128, 64, out_channels)
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.extend((
                nn.Conv2d(c_in, c_out, kernel_size=(3, 3), padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ))
        # A single Sequential named ``ft`` keeps state_dict keys
        # (ft.0.weight ... ft.7.running_var) compatible with saved checkpoints.
        self.ft = nn.Sequential(*stages)

    def forward(self, x):
        return self.ft(x)

### Improved Bi-FPN Model

In [38]:
class FourierBiFPN(nn.Module):
    """EfficientNet backbone + Bi-FPN classifier with an auxiliary Fourier branch.

    ``forward`` returns ``(logits, ft_3, ft_4, ft_5)``: class scores from the
    fully connected head, plus three 80x80 maps predicted by a shared
    ``FTGen_1`` from the last three Bi-FPN outputs.

    Args:
        num_classes: unused; the head is hard-wired to 3 outputs.
        compound_coef: EfficientDet scaling index (0-7) selecting the backbone,
            Bi-FPN width and repeat count.
        load_weights: unused.
        **kwargs: optional ``ratios`` / ``scales``; stored as attributes but not
            used by ``forward``.
    """

    def __init__(self, num_classes=3, compound_coef=0, load_weights=False, **kwargs):
        super(FourierBiFPN, self).__init__()

        self.compound_coef = compound_coef
        self.backbone_compound_coef = [0, 1, 2, 3, 4, 5, 6, 6]
        self.fpn_num_filters = [64, 88, 112, 160, 224, 288, 384, 384]
        self.fpn_cell_repeats = [3, 4, 5, 6, 7, 7, 8, 8]
        self.input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
        self.box_class_repeats = [3, 3, 3, 4, 4, 4, 5, 5]
        self.anchor_scale = [4., 4., 4., 4., 4., 4., 4., 5.]
        self.aspect_ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)])
        self.num_scales = len(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]))

        conv_channel_coef = {
            # the channels of P3/P4/P5.
            0: [40, 112, 320],
            1: [40, 112, 320],
            2: [48, 120, 352],
            3: [48, 136, 384],
            4: [56, 160, 448],
            5: [64, 176, 512],
            6: [72, 200, 576],
            7: [72, 200, 576],
        }

        # Shared Fourier-map generator. NOTE: FTGen_1() defaults to 64 input
        # channels, which equals fpn_num_filters only for compound_coef=0.
        self.ftg1 = FTGen_1()
        self.upsample1 = nn.Upsample(size=(80, 80), mode="nearest")
        self.bifpn = nn.Sequential(
            *[BiFPN(self.fpn_num_filters[self.compound_coef],
                    conv_channel_coef[compound_coef],
                    i == 0,  # first cell builds the extra pyramid levels
                    attention=compound_coef < 6)
              for i in range(self.fpn_cell_repeats[compound_coef])])
        self.backbone_net = EfficientNet_Head(compound_coef=self.backbone_compound_coef[compound_coef])
        # Novelty: fully connected head over flattened pyramid features.
        self.p3_fc = nn.Linear(16384, 4096)
        self.p4_fc = nn.Linear(4096, 1024)
        self.final_fc = nn.Sequential(*[nn.Linear(1024, 256), nn.Linear(256, 3)])

    def forward(self, inputs):
        _, p3, p4, p5 = self.backbone_net(inputs)
        # Bi-FPN takes three backbone levels and returns five; only the last
        # three are used below.
        _, _, p3, p4, p5 = self.bifpn((p3, p4, p5))

        # Auxiliary branch: ReLU -> resize to 80x80 -> shared Fourier generator.
        ft_3 = self.ftg1(self.upsample1(F.relu(p3)))
        ft_4 = self.ftg1(self.upsample1(F.relu(p4)))
        ft_5 = self.ftg1(self.upsample1(F.relu(p5)))

        # Classification head on flattened features.
        p3 = self.p3_fc(p3.reshape(p3.shape[0], -1))
        p4 = self.p4_fc(p4.reshape(p4.shape[0], -1))
        p5 = p5.reshape(p5.shape[0], -1)
        # NOTE(review): p3_fc outputs 4096 features while p4_fc outputs 1024 and
        # final_fc expects 1024 inputs, so adding a (B, 4096) tensor to (B, 1024)
        # tensors raises a shape error; this forward pass cannot complete as
        # written. Confirm the intended head against the training code/checkpoint.
        pt = p3 + p4 + p5
        return self.final_fc(pt), ft_3, ft_4, ft_5

In [39]:
from torch.autograd import Variable

def compute_loss(network, img, labels, ft_feat, logger, phase, device):
    """Compute the joint classification + Fourier-reconstruction loss for a batch.

    Args:
        network: model returning ``(logits, ft_3, ft_4, ft_5)``.
        img: input batch; moved to ``device`` here.
        labels: integer class labels in {0, 1, 2}; must be on the CPU because
            they are compared with CPU predictions.
        ft_feat: target Fourier maps, compared by MSE with each ``ft_*`` output.
        logger: unused; kept for the trainer's call signature.
        phase: unused; kept for the trainer's call signature.
        device: device on which the forward pass and losses are computed.

    Returns:
        ``(loss, accuracy, predictions)``. ``loss`` is
        ``0.5 * CE + 0.5 * mean(MSE_3, MSE_4, MSE_5)``, ``accuracy`` is the
        fraction of correct predictions in the batch (float), and
        ``predictions`` is the argmax class per sample on the CPU.
    """
    criterion_ce = nn.CrossEntropyLoss()
    criterion_mse = nn.MSELoss()
    ft_feat = ft_feat.to(device)
    # CrossEntropyLoss accepts class-probability targets, so one-hot floats work.
    y_one_hot = nn.functional.one_hot(labels, num_classes=3).to(device).float()

    out, ft_3, ft_4, ft_5 = network(img.to(device))
    _, predicted = torch.max(out.detach(), 1)
    predicted = predicted.cpu()
    acc = float((predicted == labels).sum()) / float(out.shape[0])

    loss_cls = criterion_ce(out, y_one_hot)
    # Average the reconstruction error of the three pyramid-level Fourier maps.
    loss_ft = (criterion_mse(ft_3, ft_feat)
               + criterion_mse(ft_4, ft_feat)
               + criterion_mse(ft_5, ft_feat)) / 3.0

    loss = 0.5 * loss_cls + 0.5 * loss_ft
    return loss, acc, predicted

In [40]:
import copy
import os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm 
from pathlib import Path

class Trainer(object):
    """Evaluation and checkpoint helper for the Fourier Bi-FPN models.

    ``compute_loss`` is called as
    ``compute_loss(network, img, labels, ft_feat, logger, phase, device)`` and
    must return ``(loss, accuracy, predictions)``.
    """

    def __init__(self, network, optimizer, compute_loss, learning_rate=0.001, batch_size=32,
                 device='cpu', save_interval=2, save_path=''):
        self.network = network
        self.batch_size = batch_size
        self.optimizer = optimizer
        self.compute_loss = compute_loss
        self.device = device
        self.learning_rate = learning_rate
        self.save_interval = save_interval
        self.save_path = save_path
        # Passed to compute_loss as its ``logger`` argument; test() reads it, so
        # it must exist even though no writer is created here.
        self.writer = None
        self.network.to(self.device)

    def load_model(self, model_filename):
        """Load a checkpoint written by ``save_model``.

        Tensors are mapped onto ``self.device``.

        Returns:
            ``(epoch, iteration, losses)`` as stored in the checkpoint.
        """
        cp = torch.load(model_filename, map_location=self.device)
        self.network.load_state_dict(cp['state_dict'])
        start_epoch = cp['epoch']
        start_iter = cp['iteration']
        losses = cp['loss']
        return start_epoch, start_iter, losses

    def save_model(self, output_dir, epoch=0, iteration=0, losses=None, accuracy=None):
        """Write ``model_<epoch>_<iteration>.pth`` into ``output_dir``.

        The checkpoint holds ``epoch``, ``iteration``, ``loss`` and a CPU copy of
        the ``state_dict``; the network is moved back to ``self.device`` afterwards.
        ``accuracy`` is accepted but not stored.
        """
        saved_filename = 'model_{}_{}.pth'.format(epoch, iteration)
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        saved_path = output_dir / saved_filename
        cp = {'epoch': epoch,
              'iteration': iteration,
              'loss': losses,
              'state_dict': self.network.cpu().state_dict()
              }
        torch.save(cp, saved_path)
        self.network.to(self.device)

    def test(self, dataloader):
        """Evaluate the network once over ``dataloader['val']``.

        Args:
            dataloader: mapping whose ``'val'`` entry yields
                ``(img, labels, ft_feat)`` batches.

        Prints per-batch loss/accuracy and the aggregate. The aggregate
        ``(mean_loss, accuracy)`` is averaged over samples (each batch weighted
        by its size), so a smaller final batch is not over-weighted. It is also
        returned.

        Raises:
            ValueError: if the loader yields no samples.
        """
        phase = 'val'
        loader = dataloader[phase]
        self.network.eval()
        loss_sum = 0.0
        correct_sum = 0.0
        num_samples = 0
        tq = tqdm(loader)
        # Evaluation only: gradients are not needed.
        with torch.no_grad():
            for i, (img, labels, ft_feat) in enumerate(tq):
                start = time.time()
                loss, acc, ft_preds = self.compute_loss(
                    self.network, img, labels, ft_feat, self.writer, phase, self.device)
                end = time.time()
                n = len(labels)
                loss_sum += loss.item() * n
                correct_sum += acc * n
                num_samples += n
                print(
                    f"[{1}/{1}][{i}/{len(loader)}] => LOSS: {loss.item()}, ACC: {acc}, (ELAPSED TIME: {(end - start)}), PHASE: {phase}")
                tq.set_postfix(LOSS=loss.item(), ACC=acc, MODEL_PRED=ft_preds)
        if num_samples == 0:
            raise ValueError("dataloader['val'] yielded no samples")
        epoch_val_loss = loss_sum / num_samples
        epoch_val_acc = correct_sum / num_samples

        print(f"TEST LOSS: {epoch_val_loss}, ACC: {epoch_val_acc}")

        print("EPOCH DONE")
        return epoch_val_loss, epoch_val_acc

In [41]:
gpus = [0]  # GPU ids available to this notebook
network = FourierBiFPN()

# Ensure every parameter is trainable before building the optimizer.
for name, param in network.named_parameters():
    param.requires_grad_(True)

learning_rate, weight_decay = 0.001, 1e-5
optimizer = optim.Adam(
    (p for p in network.parameters() if p.requires_grad),
    lr=learning_rate,
    weight_decay=weight_decay,
)

dataloaders = dict(val=test_loader)

Loaded pretrained weights for efficientnet-b0


In [45]:
# load_dict = torch.load(model_path / 'bifpn_new_model.pth')
# print(load_dict.keys())

dict_keys(['epoch', 'iteration', 'loss', 'state_dict'])


In [None]:
# Evaluate the improved Bi-FPN checkpoint; use the first listed GPU when present.
trainer = Trainer(
    network,
    optimizer,
    compute_loss,
    learning_rate=learning_rate,
    batch_size=batch_size,
    device=f'cuda:{gpus[0]}' if torch.cuda.is_available() else 'cpu',
)
trainer.load_model(model_path / 'bifpn_new_model.pth')

trainer.test(dataloaders)

### Old Bi-FPN Model

In [48]:
class OldFourierBiFPN(nn.Module):
    """Earlier Bi-FPN variant: same backbone/Fourier branch, parameter-free head.

    ``forward`` returns ``(probs, ft_3, ft_4, ft_5)``. ``probs`` has one column
    per used pyramid level (3 levels): each level is flattened, passed through
    a sigmoid and averaged, and the three scores are softmax-normalized.

    Args:
        num_classes: unused; the output width is fixed by the 3 pyramid levels.
        compound_coef: EfficientDet scaling index (0-7).
        load_weights: unused.
        **kwargs: optional ``ratios`` / ``scales``; stored but not used by ``forward``.
    """

    def __init__(self, num_classes=3, compound_coef=0, load_weights=False, **kwargs):
        super(OldFourierBiFPN, self).__init__()

        self.compound_coef = compound_coef
        self.backbone_compound_coef = [0, 1, 2, 3, 4, 5, 6, 6]
        self.fpn_num_filters = [64, 88, 112, 160, 224, 288, 384, 384]
        self.fpn_cell_repeats = [3, 4, 5, 6, 7, 7, 8, 8]
        self.input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
        self.box_class_repeats = [3, 3, 3, 4, 4, 4, 5, 5]
        self.anchor_scale = [4., 4., 4., 4., 4., 4., 4., 5.]
        self.aspect_ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)])
        self.num_scales = len(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]))

        conv_channel_coef = {
            # the channels of P3/P4/P5.
            0: [40, 112, 320],
            1: [40, 112, 320],
            2: [48, 120, 352],
            3: [48, 136, 384],
            4: [56, 160, 448],
            5: [64, 176, 512],
            6: [72, 200, 576],
            7: [72, 200, 576],
        }

        # Shared Fourier-map generator. NOTE: FTGen_1() defaults to 64 input
        # channels, which equals fpn_num_filters only for compound_coef=0.
        self.ftg1 = FTGen_1()
        self.upsample1 = nn.Upsample(size=(80, 80), mode="nearest")
        self.bifpn = nn.Sequential(
            *[BiFPN(self.fpn_num_filters[self.compound_coef],
                    conv_channel_coef[compound_coef],
                    i == 0,  # first cell builds the extra pyramid levels
                    attention=compound_coef < 6)
              for i in range(self.fpn_cell_repeats[compound_coef])])
        self.backbone_net = EfficientNet_Head(compound_coef=self.backbone_compound_coef[compound_coef])

    def forward(self, inputs):
        _, p3, p4, p5 = self.backbone_net(inputs)
        # Bi-FPN takes three backbone levels and returns five; only the last
        # three are used below.
        _, _, p3, p4, p5 = self.bifpn((p3, p4, p5))

        # Auxiliary branch: ReLU -> resize to 80x80 -> shared Fourier generator.
        ft_3 = self.ftg1(self.upsample1(F.relu(p3)))
        ft_4 = self.ftg1(self.upsample1(F.relu(p4)))
        ft_5 = self.ftg1(self.upsample1(F.relu(p5)))

        # PRIOR MODEL STRUCTURE: one mean-sigmoid score per pyramid level.
        scores = [torch.sigmoid(p.reshape(p.shape[0], -1)).mean(dim=1) for p in (p3, p4, p5)]
        # NOTE(review): these are probabilities; if they are fed to
        # nn.CrossEntropyLoss (as in compute_loss), softmax is effectively applied
        # twice. Kept unchanged to match existing checkpoints.
        out = F.softmax(torch.stack(scores, dim=1), dim=1)

        return out, ft_3, ft_4, ft_5

In [49]:
# Inspect the checkpoint layout; map tensors to CPU so this works without a GPU.
load_dict = torch.load(model_path / 'bifpn_old_model.pth', map_location='cpu')
print(load_dict.keys())

dict_keys(['epoch', 'iteration', 'loss', 'state_dict'])


In [None]:
# Evaluate the earlier ("old") Bi-FPN checkpoint with a fresh optimizer.
network = OldFourierBiFPN()

for name, param in network.named_parameters():
    param.requires_grad_(True)

optimizer = optim.Adam(
    (p for p in network.parameters() if p.requires_grad),
    lr=learning_rate,
    weight_decay=weight_decay,
)

trainer = Trainer(
    network,
    optimizer,
    compute_loss,
    learning_rate=learning_rate,
    batch_size=batch_size,
    device=f'cuda:{gpus[0]}' if torch.cuda.is_available() else 'cpu',
)
trainer.load_model(model_path / 'bifpn_old_model.pth')

trainer.test(dataloaders)