In [None]:
import warnings

warnings.filterwarnings("ignore")

In [None]:
# Mount Google Drive. The BAM files are copied from /content/drive/MyDrive/BAM/
# in the next cell, and the split cache and model checkpoints are written back
# under /content/drive/MyDrive/.
from google.colab import drive

drive.mount("/content/drive")

In [None]:
!rsync --archive --ignore-existing --human-readable --info progress2 "/content/drive/MyDrive/BAM/" "/tmp/BAM/"

In [None]:
import copy
import pickle
import sqlite3
import tarfile
from contextlib import closing
from datetime import datetime
from io import BytesIO
from os import cpu_count
from pathlib import Path
from typing import List, Tuple

import pandas as pd
import torch
import torch.optim
from pandas import DataFrame
from PIL import Image
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from torch import Tensor
from torch.nn import CrossEntropyLoss, Linear
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    RandomVerticalFlip,
    Resize,
    ToTensor,
)
from tqdm import tqdm, trange

In [None]:
# Seed torch's global RNG so torch-side randomness (e.g. the new head's
# initialisation and the random augmentations) is repeatable.
# `train_test_split` below is seeded separately through `random_state`.
torch.manual_seed(0)

In [None]:
!nvidia-smi

In [None]:
# Number of images per batch for the training, validation and test loaders.
BATCH_SIZE = 64

In [None]:
def make_index(path: Path) -> DataFrame:
    """Map every image in a tar archive to the byte range of its data.

    Member names are expected to look like ``<mid>.<ext>``, where ``mid`` is the
    integer media id used as the key in the SQLite label database. Directory
    entries and members whose stem is not an integer are skipped rather than
    aborting the whole scan.

    The offsets point into the raw archive file, so they are only usable with
    ``get_image_by_offset`` when the archive is uncompressed.

    Args:
        path: tar archive to scan.

    Returns:
        DataFrame indexed by ``mid`` with an ``offset`` column (byte position of
        the member's data in the archive) and a ``length`` column (size in
        bytes). The columns exist even when no member qualifies.
    """
    indices = {}

    with tarfile.open(path) as tar:
        for member in tqdm(tar, desc="Loading images"):
            if not member.isfile():
                continue

            # The database keys are integers; converting here is what makes the
            # later join with the labels line up.
            try:
                mid = int(Path(member.name).stem)
            except ValueError:
                continue

            indices[mid] = {
                "offset": member.offset_data,
                "length": member.size,
            }

    return DataFrame.from_dict(indices, orient="index", columns=["offset", "length"])


def get_labels(path: Path) -> DataFrame:
    """Load the media-type labels from the ``automatic_labels`` table.

    Args:
        path: SQLite database file containing the ``automatic_labels`` table.

    Returns:
        DataFrame indexed by ``mid`` with one float64 column per media type
        (same columns, same order, as ``BehanceDataset.labels``): 1.0 for
        ``"positive"`` and 0.0 for ``"negative"``. Rows where any column is
        ``"unsure"`` or NULL are dropped.
    """
    QUERY = """
    SELECT
        `mid`, `media_3d_graphics`, `media_comic`, `media_graphite`, `media_oilpaint`, `media_pen_ink`, `media_vectorart`, `media_watercolor`
    FROM
        `automatic_labels`
    """

    # A sqlite3 connection used as a context manager only commits/rolls back;
    # it does not close. `closing` actually releases the database file.
    with closing(sqlite3.connect(path)) as conn:
        labels = pd.read_sql(QUERY, conn, index_col=["mid"])

    return (
        labels.replace({"negative": 0.0, "positive": 1.0, "unsure": None})
        .dropna()
        # After the replacement the columns can still be object dtype; make
        # them plain floats (and fail loudly on any unexpected label string).
        .astype(float)
    )


def get_image_by_offset(path: Path, offset: int, length: int) -> Image.Image:
    """Decode one image stored as raw bytes at a known position in a file.

    Args:
        path: file holding the encoded image bytes (here: the tar archive).
        offset: byte position where the image data starts.
        length: number of bytes of image data.

    Returns:
        The decoded image converted to RGB.

    Raises:
        ValueError: if fewer than ``length`` bytes are available at ``offset``
            (e.g. a truncated archive or an index built from another file).
    """
    with open(path, "rb") as file:
        file.seek(offset)
        payload = file.read(length)

    if len(payload) != length:
        raise ValueError(
            f"expected {length} bytes at offset {offset} in {path}, got {len(payload)}"
        )

    # `convert` loads the pixel data, so the result does not depend on the
    # BytesIO buffer staying around.
    return Image.open(BytesIO(payload)).convert("RGB")


class BehanceDataset(Dataset[Tuple[Tensor, Tensor]]):
    """BAM images read by byte offset from a tar archive, with one media label each.

    Each item is ``(image, label)``:

    - ``image`` is the output of ``self.transform`` applied to the RGB image.
    - ``label`` is the index into ``labels`` of the first column holding the
      row's maximum value.

    NOTE(review): rows with several positive media collapse to the first one,
    and an all-zero row maps to index 0. Filter such rows out of ``metadata``
    before training if that is not wanted.
    """

    # Label columns joined with the "offset"/"length" index columns, indexed by mid.
    metadata: DataFrame

    # Label columns in class-index order. Keep this a list: pandas interprets a
    # tuple key differently from a list of column labels.
    labels: List[str] = [
        "media_3d_graphics",
        "media_comic",
        "media_graphite",
        "media_oilpaint",
        "media_pen_ink",
        "media_vectorart",
        "media_watercolor",
    ]

    def __init__(self, images_path: Path, metadata_path: Path, transform=None):
        """Index the archive and join it with the labels.

        Args:
            images_path: tar archive whose members are expected to be named
                ``<mid>.<ext>``.
            metadata_path: SQLite database with an ``automatic_labels`` table.
            transform: callable applied to each PIL image. ``None`` (default)
                selects the training pipeline: random resized crop to 224x224,
                random horizontal and vertical flips, tensor conversion, and
                normalisation.
        """
        super().__init__()

        labels = get_labels(metadata_path)
        indices = make_index(images_path)

        self.images = images_path
        # Inner join: keep only mids that have both labels and an image.
        self.metadata = labels.join(indices, how="inner")

        if transform is None:
            transform = Compose(
                [
                    RandomResizedCrop(size=(224, 224)),
                    RandomHorizontalFlip(),
                    RandomVerticalFlip(),
                    ToTensor(),
                    # Presumably per-channel statistics of the BAM images;
                    # they are not computed in this notebook.
                    Normalize(
                        mean=[0.7137, 0.6628, 0.6519],
                        std=[0.2970, 0.3017, 0.2979],
                    ),
                ]
            )
        self.transform = transform

    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
        """Return ``(transformed image, label index)`` for the row at position ``index``."""
        row = self.metadata.iloc[index]

        offset = int(row["offset"])
        length = int(row["length"])
        image = get_image_by_offset(self.images, offset, length)

        # Convert explicitly via numpy: the row may have object dtype after the join.
        scores = torch.as_tensor(row[self.labels].to_numpy(dtype="float32"))

        return self.transform(image), torch.argmax(scores)

    def __len__(self) -> int:
        return len(self.metadata)

In [None]:
# Build the dataset from the local copies of the image archive and the label
# database (synced to /tmp/BAM by the rsync cell above).
dataset = BehanceDataset(
    images_path=Path("/tmp/BAM/20171214-behance-styles-crowd-only-images-R4OHZduT.tar"),
    metadata_path=Path("/tmp/BAM/20171214-behance-styles-crowd-only-R4OHZduT.sqlite"),
)

In [None]:
# Keep only rows with at least one non-zero label. An all-zero row would
# otherwise be assigned class 0 by the argmax in `__getitem__`.
valid = dataset.metadata[dataset.labels].ne(0).any(axis="columns")
dataset.metadata = dataset.metadata[valid]

In [None]:
# Split *positions* in the filtered `dataset` into train / validation / test,
# and cache the split on Drive so every session uses the same held-out images.
# NOTE: the cache stores positional indices, so it goes stale (delete it)
# whenever the filtering of `dataset.metadata` changes.
SPLIT_CACHE = Path("/content/drive/MyDrive/BAM/indices.pkl")

indices = list(range(len(dataset)))

try:
    # Unpickling can execute arbitrary code; only load a file this notebook wrote.
    with open(SPLIT_CACHE, "rb") as file:
        train_with, test_with, validate_with = pickle.load(file)
except FileNotFoundError:
    train_with, test_with = train_test_split(
        indices, test_size=0.1, shuffle=True, random_state=0
    )
    train_with, validate_with = train_test_split(
        train_with, test_size=0.1, shuffle=True, random_state=0
    )

    # Write to the same path that is read above, so later runs reuse this split.
    with open(SPLIT_CACHE, "wb") as file:
        pickle.dump((train_with, test_with, validate_with), file)

assert sorted([*train_with, *validate_with, *test_with]) == indices, (
    f"{SPLIT_CACHE} does not match the current dataset; delete it and re-run."
)

# Evaluation must be deterministic. Share the already-built index through a
# shallow copy of `dataset`, but swap the random crop/flip augmentation for a
# fixed resize + centre crop. The normalisation step is reused unchanged.
normalize = next(
    step for step in dataset.transform.transforms if isinstance(step, Normalize)
)
evaluation_dataset = copy.copy(dataset)
evaluation_dataset.transform = Compose(
    [Resize(256), CenterCrop(224), ToTensor(), normalize]
)

# `cpu_count()` may return None, but DataLoader needs an int.
NUM_WORKERS = cpu_count() or 0


def make_data_loader(source, positions, *, shuffle=False, drop_last=False):
    """Wrap ``Subset(source, positions)`` in a DataLoader with the shared settings."""
    return DataLoader(
        Subset(source, positions),
        BATCH_SIZE,
        shuffle=shuffle,
        drop_last=drop_last,
        pin_memory=True,
        num_workers=NUM_WORKERS,
    )


# Reshuffle the training set every epoch. Keep every validation/test sample:
# dropping the last partial batch there would silently skip images.
training_data_loader = make_data_loader(
    dataset, train_with, shuffle=True, drop_last=True
)
validation_data_loader = make_data_loader(evaluation_dataset, validate_with)
testing_data_loader = make_data_loader(evaluation_dataset, test_with)

In [None]:
# Report how many samples ended up in each split.
for caption, split in (
    ("Training set size:", train_with),
    ("Validation set size:", validate_with),
    ("Test set size:", test_with),
):
    print(caption, len(split))

In [None]:
def train(
    model,
    training_loader,
    validation_loader,
    epochs,
    learning_rate,
    device,
    patience,
    decrease_lr_interval,
    save_to: Path,
):
    """Train ``model`` with Adam and cross-entropy, checkpointing on validation gains.

    Args:
        model: classifier producing (batch, classes) logits. This shape is
            assumed from the ``argmax(dim=1)`` below.
        training_loader: iterable of ``(inputs, targets)`` batches used for updates.
        validation_loader: passed to ``validate`` after every epoch.
        epochs: maximum number of epochs.
        learning_rate: initial Adam learning rate.
        device: device the model and every batch are moved to.
        patience: stop once this many epochs have passed since the last checkpoint.
        decrease_lr_interval: every this many epochs, the learning rate is set to
            ``learning_rate * 0.5 ** (epoch // decrease_lr_interval)``.
        save_to: directory (created if missing) that receives ``{epoch}.pth``.
            Each checkpoint holds the epoch, the model and optimizer state
            dicts, and the validation accuracy. A checkpoint is written whenever
            validation accuracy matches or beats the best so far.

    Returns:
        ``model`` on ``device`` with the weights of the *last* epoch run, which
        is not necessarily the best checkpoint on disk.
    """
    save_to.mkdir(exist_ok=True, parents=True)

    model = model.to(device)

    optimizer = Adam(model.parameters(), lr=learning_rate)
    criterion = CrossEntropyLoss()

    validation_accuracy = 0.0  # best validation accuracy so far
    last_saved = 0  # epoch of the most recent checkpoint

    with trange(epochs) as progress:
        for epoch in progress:
            # NOTE(review): `train()` also switches BatchNorm layers back to
            # training mode, so if the caller froze a backbone containing
            # BatchNorm, its running statistics still change. Confirm intended.
            model.train()

            if epoch % decrease_lr_interval == 0:
                for param_group in optimizer.param_groups:
                    param_group["lr"] = learning_rate * (
                        0.5 ** (epoch // decrease_lr_interval)
                    )

            correct = 0
            seen = 0
            training_accuracy = 0.0
            # Pre-set so the postfix after the loop works even for an empty loader.
            batch = 0

            for batch, (x, y) in enumerate(training_loader, start=1):
                optimizer.zero_grad()

                x = x.to(device)
                y = y.to(device)

                z = model(x)

                loss = criterion(z, y)
                loss.backward()

                optimizer.step()

                # Sample-weighted running accuracy over this epoch.
                correct += (torch.argmax(z, dim=1) == y).sum().item()
                seen += y.numel()
                training_accuracy = correct / seen

                progress.set_postfix(
                    batch=batch,
                    training_accuracy=training_accuracy,
                    validation_accuracy=validation_accuracy,
                    last_saved=last_saved,
                )

            accuracy = validate(model, validation_loader, device)
            if accuracy >= validation_accuracy:
                validation_accuracy = accuracy
                last_saved = epoch

                state = {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "validation_accuracy": validation_accuracy,
                }

                torch.save(state, save_to / f"{epoch}.pth")

            progress.set_postfix(
                batch=batch,
                training_accuracy=training_accuracy,
                validation_accuracy=validation_accuracy,
                last_saved=last_saved,
            )

            if epoch - last_saved >= patience:
                break

    return model

In [None]:
def validate(model, loader, device):
    """Return the accuracy of ``model`` over every sample yielded by ``loader``.

    Accuracy is correct predictions divided by total samples, so a short final
    batch only counts for the samples it contains. The forward passes run under
    ``torch.no_grad()`` because no gradients are needed. The model is left in
    eval mode.

    Args:
        model: classifier producing (batch, classes) logits. This shape is
            assumed from the ``argmax(dim=1)``.
        loader: iterable of ``(inputs, targets)`` batches.
        device: device each batch is moved to.

    Returns:
        Accuracy as a float in [0, 1]; 0.0 when ``loader`` yields nothing.
    """
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            predictions = torch.argmax(model(x), dim=1)
            correct += (predictions == y).sum().item()
            total += y.numel()

    return correct / total if total else 0.0

In [None]:
from sklearn.metrics import confusion_matrix


def test(model, loader, device):
    """Evaluate ``model`` on ``loader``, print its confusion matrix, and return accuracy.

    Runs without autograd and leaves the model in eval mode. Accuracy is
    correct predictions divided by total samples over the whole loader.

    Args:
        model: classifier producing (batch, classes) logits. This shape is
            assumed from the ``argmax(dim=1)``.
        loader: iterable of ``(inputs, targets)`` batches.
        device: device the inputs are moved to.

    Returns:
        Accuracy as a float in [0, 1]; 0.0 (with nothing printed) when
        ``loader`` yields nothing.
    """
    model.eval()

    y_pred = []
    y_true = []

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)

            # Collect on the CPU: the targets never need to visit the device.
            y_pred.append(torch.argmax(model(x), dim=1).cpu())
            y_true.append(y.cpu())

    if not y_true:
        return 0.0

    y_pred = torch.cat(y_pred)
    y_true = torch.cat(y_true)

    print(confusion_matrix(y_true, y_pred))

    return (y_pred == y_true).sum().item() / y_true.numel()

In [None]:
model = torch.hub.load("RF5/danbooru-pretrained", "resnet50")

# Freeze everything, then re-enable gradients only for the classification head.
for param in model.parameters():
    param.requires_grad = False

# Head: replace the final layer with one output per media class. The input
# width is read from the layer being replaced instead of hard-coding 512.
# Assumes that layer is a `Linear`, as the original `Linear(512, 7)` implies.
model[-1][-1] = Linear(model[-1][-1].in_features, len(BehanceDataset.labels))

for param in model[-1].parameters():
    param.requires_grad = True

In [None]:
# Fall back to the CPU when no GPU runtime is attached (much slower, but the
# notebook still runs end to end).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

epochs = 100
lr = 1e-3

# Keep the run directory so the best checkpoint can be located after training.
run_directory = (
    Path("/content/drive/MyDrive/Behance/Models") / "ResNet50" / str(datetime.now())
)

net = train(
    model,
    training_data_loader,
    validation_data_loader,
    epochs,
    lr,
    device,
    patience=100,
    decrease_lr_interval=20,
    save_to=run_directory,
)

# Training leaves `model` with the weights of the last epoch it ran. A
# checkpoint is only written when validation accuracy matches or beats the best
# so far, so the highest-numbered checkpoint is the best model. Restore it so
# the test cell evaluates those weights.
best_checkpoint = max(
    run_directory.glob("*.pth"), key=lambda path: int(path.stem), default=None
)
if best_checkpoint is not None:
    checkpoint = torch.load(best_checkpoint, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])

In [None]:
# Evaluate `model` on the held-out test split. `test` prints the confusion
# matrix, and the accuracy it returns is shown as this cell's output.
test(model, testing_data_loader, device)