In [None]:
# data
import pandas as pd
import numpy as np
import string
import re
import os
from ast import literal_eval

# viz
import matplotlib.pyplot as plt
from PIL import Image

# parallel
from joblib import Parallel, delayed
import multiprocessing

# custom
import importlib
import proj_funs
importlib.reload(proj_funs)
from proj_funs import extract_subjects, read_saved, save_covers

# ML
from torch.utils.data.dataset import Dataset
from torchvision import transforms
import torchvision.models as models
import torch
from sklearn.metrics import precision_score, recall_score, f1_score
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torchvision.models.resnet import ResNet, BasicBlock

# Number of CPUs visible to this kernel (shown as the cell's output).
n_cpus = multiprocessing.cpu_count()
n_cpus

# Import subject data

In [None]:
# Load the saved subject data via the project helper.
CLEAN_DATA_FILE = "df_clean_uniqkey.csv"
df = proj_funs.read_saved(CLEAN_DATA_FILE)

In [None]:
# Preview only: rendering the whole frame bloats the saved notebook.
df.head()

In [None]:
# Label vocabulary: every distinct subject across df.subjects.
# Sorted so the class -> model-output-index mapping is identical in every
# kernel session. Iterating a set of strings depends on per-process hash
# randomization, so the previous list(set(...)) order was not stable.
lables = sorted(set(np.concatenate(df.subjects.values).flat))
lables

## Subset data

In [None]:
# Work on a fixed, reproducible random subset of the data.
N = 12_000
df_sample = (
    df.sample(n=N, random_state=1)
      .reset_index(drop=True)
)
df_sample

# Img processing

## Download imgs

In [None]:
image_ids = np.ravel(df_sample.cover.values)
image_size = "M"
# One single-id save_covers call per cover, so joblib can run them in parallel.
download_jobs = (delayed(save_covers)([cover_id], image_size) for cover_id in image_ids)
Parallel(n_jobs=100)(download_jobs)
print("DONE!")

# Join Data

In [None]:
# Cover directory and the size suffix used in '<cover>-<size>.jpg' file names.
img_folder, size = 'covers', "M"

In [None]:
class OLDataset(Dataset):
    """Multi-label dataset of cover images paired with subject labels.

    Each item is ``(image, target)``:
      * ``image`` is the cover file loaded from ``data_path``, converted to RGB,
        then passed through ``transforms`` when one is given;
      * ``target`` is a float multi-hot numpy vector over ``self.classes``
        (1.0 where the class appears in the row's subjects).

    Args:
        data_path: Directory containing the cover files.
        samples: DataFrame with a ``cover`` column (file id) and a ``subjects``
            column (the subjects of that row; membership is tested with ``in``).
        transforms: Optional callable applied to the PIL image, or None.
        classes: Ordered label vocabulary. Defaults to the notebook-level ``lables``.
        img_size: Size suffix in the file name ``<cover>-<img_size>.jpg``.
            Defaults to the notebook-level ``size``.
    """
    def __init__(self, data_path, samples, transforms, classes=None, img_size=None):
        self.transforms = transforms
        self.classes = lables if classes is None else classes
        img_size = size if img_size is None else img_size

        self.data_path = data_path
        self.imgs = []
        self.annos = []
        # Single pass: build the file name and the multi-hot target per row.
        for cover, subjects in zip(samples['cover'], samples['subjects']):
            self.imgs.append(str(cover) + "-" + img_size + ".jpg")
            self.annos.append(np.array([cls in subjects for cls in self.classes], dtype=float))

    def __getitem__(self, item):
        anno = self.annos[item]
        img_path = os.path.join(self.data_path, self.imgs[item])
        # Force 3-channel RGB. Grayscale/palette/CMYK covers would otherwise
        # yield tensors that do not match the 3-channel Normalize or the ResNet
        # stem. The context manager closes the file handle that Image.open
        # keeps open for lazy loading; convert() returns a fully loaded copy.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        if self.transforms is not None:
            img = self.transforms(img)
        return img, anno

    def __len__(self):
        return len(self.imgs)

## Split train/val/test

In [None]:
# Fractions of df_sample per split; validation takes whatever remains.
train_pct = 0.8334
test_pct = 0.16667 / 2
val_pct = (1 - train_pct) - test_pct

In [None]:
# Contiguous row ranges of df_sample: [0, train_end) train,
# [train_end, val_end) validation, [val_end, end) test.
n_rows = len(df_sample)
train_end = int(n_rows * train_pct) + 1
val_end = train_end + int(n_rows * val_pct) + 1

dataset_train = OLDataset("covers", df_sample[:train_end], None)
dataset_val = OLDataset("covers", df_sample[train_end:val_end], None)
dataset_test = OLDataset("covers", df_sample[val_end:], None)

In [None]:
print(dataset_val[0])

# A simple function for visualization.
def show_sample(img, binary_img_labels):
    """Display `img`, titled with the class names whose label entries are > 0.

    Class names are read from the notebook-level `dataset_val.classes`.
    """
    label_names = np.array(dataset_val.classes)
    positives = label_names[np.argwhere(binary_img_labels > 0)[:, 0]]
    plt.imshow(img)
    plt.title(', '.join(positives))
    plt.axis('off')
    plt.show()

for sample_id in range(50, 70):
    cover_img, cover_labels = dataset_val[sample_id]
    show_sample(cover_img, cover_labels)

In [None]:
# Label distribution over the validation + training splits (the test split is not included).
samples = np.array(dataset_val.annos + dataset_train.annos)
class_names = np.array(dataset_val.classes)
with np.printoptions(precision=3, suppress=True):
    class_counts = np.sum(samples, axis=0)
    # Label indices from least to most frequent; "stable" keeps ties in label order.
    sorted_ids = np.argsort(class_counts, kind="stable")
    print('Label distribution (count, class name):', list(zip(class_counts[sorted_ids].astype(int), class_names[sorted_ids])))
    plt.barh(range(len(class_names)), width=class_counts[sorted_ids])
    plt.yticks(range(len(class_names)), class_names[sorted_ids])
    plt.gca().margins(y=0)
    plt.grid()
    plt.title('Label distribution')
    plt.show()

## Preprocess imgs

In [None]:
# Per-channel RGB normalization statistics (the ImageNet values documented for torchvision models).
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

In [None]:
# Test preprocessing: tensor conversion + normalization only (no augmentation).
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
# The mean rescaled to 0-255 pixel units.
print(tuple((np.array(mean) * 255).tolist()))

# Train preprocessing: random photometric jitter, then the same tensor/normalize steps.
# ColorJitter() with no arguments has every factor at 0 and returns the image
# unchanged, so explicit strengths are required for any augmentation to happen.
train_transform = transforms.Compose([
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

# Classification
Approach adapted from the LearnOpenCV PyTorch multi-label image classification example: https://github.com/spmallick/learnopencv/tree/master/PyTorch-Multi-Label-Image-Classification-Image-Tagging

## Dataloaders

In [None]:
# init the dataloaders for training
# NOTE(review): the transforms contain no Resize, so batch_size > 1 presumably
# requires all covers to share one shape for default collation. Verify before raising it.
batch_size = 1

# Same contiguous split boundaries as the train/val/test datasets above.
n_rows = len(df_sample)
train_end = int(n_rows * train_pct) + 1
val_end = train_end + int(n_rows * val_pct) + 1

# Train on the training slice (with augmentation) and evaluate on the held-out
# test slice. These were previously swapped: the model trained on the ~8% test
# rows and was evaluated on the ~83% training rows.
train_dataset = OLDataset("covers", df_sample[:train_end], train_transform)
test_dataset = OLDataset("covers", df_sample[val_end:], val_transform)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

# drop_last=True skips a trailing partial batch, so ask the loader instead of using ceil().
num_train_batches = len(train_dataloader)

## Model

In [None]:
# https://pytorch.org/vision/0.8/_modules/torchvision/models/resnet.html
# Based on torchvision's ResNet-18 (BasicBlock, [2, 2, 2, 2]) with the
# classifier head replaced: a 1x1 conv maps the pooled 512-channel features to
# one score per class, and a Sigmoid turns each score into an independent
# per-label probability (multi-label) instead of returning raw logits.
class FCResNet18(ResNet):
    """ResNet-18 backbone with a 1x1-conv multi-label head.

    Args:
        n_classes: Number of output labels.

    The forward pass returns a ``(batch, n_classes)`` tensor of sigmoid outputs.
    """
    def __init__(self, n_classes):
        super().__init__(BasicBlock, [2, 2, 2, 2])
        self.sigm = nn.Sigmoid()
        # Sized by n_classes. It was hard-coded to 22, which silently mismatched
        # any label vocabulary of a different size.
        self.final_conv = nn.Conv2d(in_channels=512, out_channels=n_classes, kernel_size=1)
        # NOTE: not used by _forward_impl. Kept so the parameter / state_dict
        # names are unchanged.
        self.fc = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=512, out_features=n_classes)
        )

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)      # (B, 512, 1, 1) after torchvision's adaptive avg pool
        x = self.final_conv(x)   # (B, n_classes, 1, 1)
        # Keep the batch dimension. The previous max-over-dims reduction collapsed
        # the whole batch into one row and was only correct for batch_size == 1
        # (where it gives exactly this result).
        x = torch.flatten(x, 1)  # (B, n_classes)
        x = self.sigm(x)

        return x
            

In [None]:
# Sanity check: Sigmoid maps arbitrary logits into (0, 1).
# (Renamed from `input`, which shadowed the Python builtin.)
sigmoid = nn.Sigmoid()
sample_logits = torch.randn(2)
sigmoid(sample_logits)

In [None]:
# Probe the max-based reduction on a (1, 22, 1, 1) tensor.
x = torch.randn(1, 22, 1, 1)
x = torch.max(x, 0).values   # max over dim 0 -> (22, 1, 1)
x = torch.max(x, 1).values   # max over dim 1 -> (22, 1)
x = x.transpose(1, 0)        # -> (1, 22)
x

## Train Model

### Training parameters & metrics

In [None]:
# Initialize the training parameters.
lr = 1e-4  # Learning rate
test_freq = 200  # Evaluate on the test loader every `test_freq` iterations
max_epoch_number = 35  # Max num of training epochs

# Seed torch's global RNG so weight init, DataLoader shuffling and ColorJitter
# draws are repeatable (same seed value as df.sample's random_state).
SEED = 1
torch.manual_seed(SEED)

# Initialize the model: one sigmoid output per label.
model = FCResNet18(len(lables))
# NOTE(review): torchvision's resnet18 weights do not match this head
# (no final_conv, differently named fc). Loading them would need
# load_state_dict(..., strict=False).

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Loss function: binary cross-entropy on the model's per-label sigmoid outputs.
criterion = nn.BCELoss()

In [None]:
# Use threshold to define predicted labels and invoke sklearn's metrics with different averaging strategies.
def calculate_metrics(pred, target, threshold=0.5, zero_division="warn"):
    """Binarize multi-label scores at `threshold` and compute sklearn P/R/F1.

    Args:
        pred: Scores/probabilities, presumably (n_samples, n_labels). A numpy
            array or CPU tensor both work.
        target: Matching multi-hot ground truth.
        threshold: A score strictly greater than this counts as a positive prediction.
        zero_division: Forwarded to sklearn. "warn" (the default, unchanged
            behavior) scores undefined ratios as 0 and emits a warning. Pass 0
            to get the same numbers without the warning.

    Returns:
        dict keyed '<average>/<metric>' for average in micro/macro/samples and
        metric in precision/recall/f1 (in that order).
    """
    # Compare before converting, so tensors that require grad never reach numpy.
    pred = np.array(pred > threshold, dtype=float)
    scorers = {'precision': precision_score, 'recall': recall_score, 'f1': f1_score}
    return {
        f'{average}/{name}': scorer(y_true=target, y_pred=pred, average=average,
                                    zero_division=zero_division)
        for average in ('micro', 'macro', 'samples')
        for name, scorer in scorers.items()
    }

### Train Model

In [None]:
# Run training
def _evaluate(model, dataloader, threshold=0.5):
    """Return calculate_metrics() computed over every batch of `dataloader`.

    Runs the pass in eval mode under no_grad, then switches the model back to
    train mode before returning.
    """
    model.eval()
    batch_preds, batch_targets = [], []
    with torch.no_grad():
        for imgs, targets in dataloader:
            batch_preds.append(model(imgs))
            batch_targets.append(targets)
    model.train()
    if not batch_preds:
        raise ValueError("evaluation dataloader yielded no batches")
    return calculate_metrics(torch.cat(batch_preds).numpy(),
                             torch.cat(batch_targets).numpy(),
                             threshold)


model.train()
iteration = 0
# Exactly max_epoch_number epochs. The old `max_epoch_number < epoch` check ran one extra epoch.
for epoch in range(max_epoch_number):
    batch_losses = []
    for imgs, targets in train_dataloader:
        optimizer.zero_grad()

        model_result = model(imgs)
        loss = criterion(model_result, targets.type(torch.float))

        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

        # Periodic held-out evaluation. The per-batch train metrics the loop used
        # to compute were never read, so they are no longer computed.
        if iteration % test_freq == 0:
            result = _evaluate(model, test_dataloader)
            print("epoch:{:2d} iter:{:3d} test: "
                  "micro f1: {:.3f} "
                  "macro f1: {:.3f} "
                  "samples f1: {:.3f}".format(epoch, iteration,
                                              result['micro/f1'],
                                              result['macro/f1'],
                                              result['samples/f1']))
        iteration += 1

    loss_value = np.mean(batch_losses)
    print("epoch:{:2d} iter:{:3d} train: loss:{:.3f}".format(epoch, iteration, loss_value))

In [None]:
# Run inference on a few covers from test_dataset and compare with their ground-truth labels.
model.eval()
class_names = np.array(test_dataset.classes)
for sample_id in [1, 2, 3, 4, 6]:
    test_img, test_labels = test_dataset[sample_id]
    test_img_path = os.path.join(img_folder, test_dataset.imgs[sample_id])
    with torch.no_grad():
        raw_pred = model(test_img.unsqueeze(0))[0]
        raw_pred = np.array(raw_pred > 0.5, dtype=float)

    predicted_labels = class_names[np.argwhere(raw_pred > 0)[:, 0]]
    if not len(predicted_labels):
        predicted_labels = ['no predictions']
    img_labels = class_names[np.argwhere(test_labels > 0)[:, 0]]
    # Close the cover file once it is drawn. imshow converts the PIL image to an array immediately.
    with Image.open(test_img_path) as cover:
        plt.imshow(cover)
    plt.title("Predicted labels: {} \nGT labels: {}".format(', '.join(predicted_labels), ', '.join(img_labels)))
    plt.axis('off')
    plt.show()