In [23]:
import random

import numpy as np
import matplotlib.pyplot as plt
import torch

from tqdm import tqdm

from torch import nn
from torchvision.models.resnet import resnet18, ResNet18_Weights
from torchvision.datasets import ImageFolder
from torchvision import transforms as T
from torch.utils.data import random_split, DataLoader

# Select the 'medium' float32 matmul precision setting for this session.
# NOTE(review): presumably chosen to speed up training at some cost in
# numerical precision — confirm this is acceptable for the task.
torch.set_float32_matmul_precision('medium')

from dlc.trainers.img_classification import launch_training

In [24]:
def is_black_img(sample, threshold=0.01):
    """Tell whether a sample's image is (almost) entirely dark.

    Args:
        sample: An indexable whose first element is the image tensor; any
            further elements (e.g. the label) are ignored.
        threshold: Upper bound that the image's 99th-percentile value must
            stay strictly below for the image to count as black.

    Returns:
        The result of comparing the 0.99 quantile of the image against
        ``threshold`` (a boolean tensor, not a Python ``bool``).

    NOTE(review): in this notebook the samples come from a dataset that is
    normalised with mean 0.5 / std 0.5, so ``threshold`` is compared against
    normalised values rather than raw intensities — confirm 0.01 is the
    intended cut-off in that space.
    """
    image = sample[0]
    # Using the 0.99 quantile instead of the max lets a few bright pixels
    # through without disqualifying an otherwise dark image.
    return torch.quantile(image, 0.99) < threshold

In [25]:
# Tunable settings for this cell, gathered in one place.
DATA_ROOT = "D:/data/images/galaxy10_unamur/train"
SPLIT_SEED = 42
BATCH_SIZE = 64

# Convert each image to a tensor, then normalise all three channels with
# mean 0.5 and std 0.5 (display_sample undoes this with `img * 0.5 + 0.5`).
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
dataset = ImageFolder(root=DATA_ROOT, transform=transform)

# Use a dedicated, seeded generator for the 80/20 split so the same samples
# end up in the train and test subsets on every run; an unseeded split makes
# every downstream number (black-image count, validation metrics)
# irreproducible.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, test_dataset = random_split(
    dataset,
    lengths=(0.8, 0.2),
    generator=split_generator,
)
print("#train samples:", len(train_dataset))
print("#test samples:", len(test_dataset))

# Only the training loader shuffles; the test loader keeps a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
print("#train batches:", len(train_dataloader))
print("#test batches:", len(test_dataloader))

#train samples: 14691
#test samples: 3672
#train batches: 230
#test batches: 58


In [26]:
# Count the training samples that is_black_img flags; every sample is loaded
# and transformed once, with tqdm reporting progress.
sum(is_black_img(sample) for sample in tqdm(train_dataset))

100%|██████████| 14691/14691 [03:06<00:00, 78.61it/s]


tensor(1860)

In [None]:
def display_sample(sample, unnormalize=True, ax=None):
    """Draw one ``(image, label)`` sample on a matplotlib axes.

    Args:
        sample: Pair of an image tensor and the index of its class in
            ``dataset.classes``. The image is permuted with ``(1, 2, 0)``
            before plotting, so it is expected to be channel-first.
        unnormalize: When true, map the image through ``x * 0.5 + 0.5``
            before showing it (the inverse of a mean-0.5 / std-0.5
            normalisation).
        ax: Axes to draw on; the current axes is used when omitted.
    """
    if ax is None:
        ax = plt.gca()

    image, class_idx = sample
    if unnormalize:
        image = image * 0.5 + 0.5

    # The class name comes from the module-level `dataset`.
    ax.set_title(dataset.classes[class_idx])
    ax.axis('off')
    # imshow wants channel-last data on the host as a NumPy array.
    channel_last = image.permute(1, 2, 0)
    ax.imshow(channel_last.cpu().numpy())

def display_batch(batch, unnormalize=True):
    """Show up to eight randomly chosen samples from a batch in a 2x4 grid.

    Args:
        batch: An ``(imgs, labels)`` pair. The two are zipped together, so
            they are expected to be aligned element by element.
        unnormalize: Forwarded to ``display_sample`` for every image.

    A batch holding fewer than eight samples is shown in full, leaving the
    remaining grid cells empty.
    """
    imgs, labels = batch
    pairs = list(zip(imgs, labels))
    # random.sample raises ValueError when k exceeds the population size,
    # which would happen for a short batch (e.g. the last one of an epoch),
    # so clamp k to what is available.
    samples = random.sample(pairs, k=min(8, len(pairs)))

    fig = plt.figure(figsize=(16, 8))
    for i, sample in enumerate(samples):
        # Pass the axes explicitly rather than relying on pyplot's
        # "current axes" state.
        ax = fig.add_subplot(2, 4, i + 1)
        display_sample(sample, unnormalize=unnormalize, ax=ax)
    fig.tight_layout()
    plt.show()

In [None]:
# Pull a single batch from the training loader and preview it.
preview_batch = next(iter(train_dataloader))
display_batch(preview_batch)

In [None]:
# Build a ResNet-18; no `weights` argument is passed, so no pretrained
# weights are requested here.
model = resnet18()
# Replace the classification head with one output per dataset class. The
# input width is read from the existing head instead of hard-coding 512, so
# the line stays correct if the backbone is swapped for another ResNet.
model.fc = nn.Linear(model.fc.in_features, len(dataset.classes))
model

In [None]:
# Start training; the held-out test split is passed as the validation loader.
launch_training(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=test_dataloader,
)