In [None]:
import os
import random

import torch
from mlcpl import datasets
from torchvision import transforms
from mlcpl.augs import LogicMix

# Root of the COCO dataset. Override it with the COCO_ROOT environment variable
# instead of editing a hard-coded absolute path on every machine.
DATA_DIR = os.environ.get('COCO_ROOT', '/home/max/datasets/COCO')
SEED = 0

# Seed Python's and torch's RNGs before anything random is built. This covers
# torch-driven randomness: torchvision augmentations, DataLoader shuffling and
# worker seeds, and the new classifier head's init. The goal is for that
# randomness to repeat on Restart & Run All.
# NOTE(review): if mlcpl's drop_labels_uniform / LogicMix draw from numpy's
# RNG, seed numpy as well. CUDA kernels may still be nondeterministic.
random.seed(SEED)
torch.manual_seed(SEED)

image_size = (224, 224)

# Training pipeline: random horizontal flip + RandAugment, then resize to a
# fixed size and convert to a tensor.
train_dataset_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandAugment(interpolation=transforms.functional.InterpolationMode.BILINEAR),
    transforms.Resize(image_size),
    transforms.ToTensor(),
])

# Validation pipeline: deterministic resize + tensor conversion only.
valid_dataset_transforms = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
])

# Training split with labels dropped. Presumably this simulates a
# partially-labelled setting that keeps ~10% of labels — confirm in mlcpl.
train_dataset = datasets.MSCOCO(
    DATA_DIR,
    split='train',
    transform=train_dataset_transforms,
).drop_labels_uniform(target_label_proportion=0.1)

# Wrap the training set with the LogicMix augmentation.
# NOTE(review): presumably mixes k_min..k_max samples with probability 0.1 —
# confirm the exact semantics in mlcpl.augs.
train_dataset = LogicMix(train_dataset, probability=0.1, k_min=2, k_max=3)

# Validation split: no labels are dropped here.
valid_dataset = datasets.MSCOCO(
    DATA_DIR,
    split='valid',
    transform=valid_dataset_transforms,
)


In [None]:
from torchvision.models import resnet18
from torch import nn

# ImageNet-pretrained ResNet-18 backbone. Its classifier is replaced by a fresh
# linear head that outputs one raw logit per category.
model = resnet18(weights='IMAGENET1K_V1')
# Read the head's input width from the existing layer instead of hard-coding
# 512, so swapping in another ResNet variant (e.g. resnet50 -> 2048) just works.
# NOTE(review): train_dataset is the LogicMix wrapper at this point; this
# assumes it exposes num_categories like the underlying MSCOCO dataset — confirm.
model.fc = nn.Linear(model.fc.in_features, train_dataset.num_categories)
model = model.to('cuda')
 

In [None]:
from torch.optim import Adam
from mlcpl.losses import PartialAsymmetricWithLogitLoss
from mlcpl.metrics import (
    partial_multilabel_auroc,
    partial_multilabel_average_precision,
    partial_multilabel_f1_score,
)

# Partial-label asymmetric loss applied to the model's raw logits.
# NOTE(review): gamma_neg / gamma_pos / clip presumably follow the Asymmetric
# Loss (ASL) formulation — confirm their exact meaning in mlcpl.
# With reduction='sum' the loss (and the "Train loss" printed during training)
# presumably grows with the batch size rather than being a per-sample mean.
loss_fn = PartialAsymmetricWithLogitLoss(
    gamma_neg=4,
    gamma_pos=0,
    clip=0.05,
    reduction='sum',
)

# Adam over every parameter of the model, including the new head.
optimizer = Adam(model.parameters(), lr=0.001)

# Validation metrics, each called as metric(preds, targets).
# NOTE(review): preds are raw logits. AP and AUROC are rank-based, so logits
# are fine for them; check which threshold partial_multilabel_f1_score applies.
metrics = {
    'mAP': partial_multilabel_average_precision,
    'mAUC': partial_multilabel_auroc,
    'mF1': partial_multilabel_f1_score,
}

In [None]:
import torch

def train(model, dataloader, loss_fn, optimizer, device='cuda'):
    """Train ``model`` for a single pass over ``dataloader``.

    Args:
        model: ``nn.Module`` already placed on ``device``.
        dataloader: iterable of ``(x, y)`` batches that also supports ``len()``.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: optimizer over the parameters being trained.
        device: device each batch is moved to before the forward pass.

    Returns:
        float: mean of the per-batch losses, or ``nan`` if ``dataloader`` is empty.
    """
    model.train()

    num_batches = len(dataloader)
    total_loss = 0.0
    seen_batches = 0

    for batch, (x, y) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)

        # Clear gradients *before* backward. Zeroing only after step() lets
        # gradients left over from earlier work leak into the first update.
        optimizer.zero_grad()
        pred = model(x)
        loss = loss_fn(pred, y)
        loss.backward()
        optimizer.step()

        # detach() keeps the running total out of the autograd graph; the
        # host sync for the total is deferred to the end of the epoch.
        total_loss += loss.detach()
        seen_batches += 1

        if batch % 10 == 0:
            print(f'Training... Batch: {batch}/{num_batches}, Train loss: {loss.item():.4f}', end='\r')

    print()
    return float(total_loss) / seen_batches if seen_batches else float('nan')

def evaluate(model, dataloader, metrics, device='cuda'):
    """Run ``model`` over ``dataloader`` and score the pooled outputs with ``metrics``.

    Each metric is printed as ``valid_<name>: <value>``.

    Args:
        model: ``nn.Module`` already placed on ``device``.
        dataloader: iterable of ``(x, y)`` batches that also supports ``len()``.
            Batches may have any size, so loaders built with a ``batch_sampler``
            (``batch_size is None``) or ``drop_last=True`` work too.
        metrics: mapping ``name -> metric(preds, targets)``. Each metric must return
            a scalar, either a one-element tensor or a Python number.
        device: device the inputs are moved to for the forward pass.

    Returns:
        dict[str, float]: metric name -> value.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()

    batch_preds, batch_targets = [], []
    num_batches = len(dataloader)

    with torch.no_grad():
        for batch, (x, y) in enumerate(dataloader):
            # Only the inputs need to be on the model's device. Targets stay on
            # (or move to) the CPU, where the metrics are computed.
            pred = model(x.to(device))
            batch_preds.append(pred.cpu().float())
            batch_targets.append(y.cpu().float())

            if batch % 10 == 0:
                print(f'Validating... Batch: {batch}/{num_batches}', end='\r')
        print()

    if not batch_preds:
        raise ValueError('evaluate() received a dataloader that yielded no batches')

    # Concatenate per-batch outputs instead of slicing a preallocated buffer
    # by batch * batch_size. That slicing breaks when batch_size is None and
    # leaves unscored zero rows when batches are dropped.
    preds = torch.cat(batch_preds)
    targets = torch.cat(batch_targets)

    results = {}
    for name, metric in metrics.items():
        results[name] = float(metric(preds, targets))
        print(f'valid_{name}: {results[name]:.4f}')
    return results

In [None]:
from torch.utils.data import DataLoader


num_epoches = 20
batch_size = 256
num_workers = 16

# persistent_workers keeps the worker processes alive across epochs instead of
# re-spawning them at the start of every epoch. DataLoader only allows it when
# num_workers > 0. pin_memory speeds up the host -> GPU copies done in train()
# and evaluate().
# NOTE(review): persistent workers will not see dataset state mutated in the
# main process between epochs; nothing visible here does that.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    persistent_workers=num_workers > 0,
    pin_memory=True,
)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    persistent_workers=num_workers > 0,
    pin_memory=True,
)

# One training pass followed by a validation pass per epoch.
for epoch in range(num_epoches):
    print(f'Epoch: {epoch+1}')
    train(model, dataloader=train_dataloader, loss_fn=loss_fn, optimizer=optimizer)
    evaluate(model, dataloader=valid_dataloader, metrics=metrics)