![Rampart Logo](../images/logo.png)

Auge is an image classification model. Its main goal is to recognize ordinary photos in order to determine several features of an apartment listing. Each image belongs to a specific property; once a property's photos and panoramas are recognized, **twinkle** can better predict the apartments' order.

## I/O
The required images are located at `../scientific/images`. The database currently contains images of two types: simple photos and wide (360°) panoramas. Panoramas should be omitted until a stable classification model is available. The final classifier must be stored at `../models/auge.latest.pth`.

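The notebook code itself does not check whether an image is a panorama. The following is only a possible heuristic and an assumption, not the project's rule. Equirectangular 360° panoramas usually have a 2:1 aspect ratio, while ordinary photos are closer to 4:3, 3:2 or 16:9. The helper `is_panorama` and its `tolerance` parameter are hypothetical. The width and height could come from PIL's `Image.size`.

```python
def is_panorama(width: int, height: int, tolerance: float = 0.05) -> bool:
    """Hypothetical heuristic: treat images with a ~2:1 aspect ratio as 360° panoramas.

    `tolerance` is relative to the 2:1 ratio, so 0.05 accepts ratios of roughly 1.9 to 2.1.
    """
    if width <= 0 or height <= 0:
        raise ValueError('image dimensions must be positive')
    return abs(width / height - 2.0) <= 2.0 * tolerance


# Example dimensions, made up for illustration
for size in [(4096, 2048), (1600, 1200), (1920, 1080)]:
    print(size, is_panorama(*size))
```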
## Metadata
Each image file has a self-explanatory name of the form:
```
<hash>.<effect>.<group>.<label>.webp
```
For instance:
```
0018559490dd9cb73caa00f078fde40b220803de.balance_down_hue_cw_rotate_ccw_crop.training.construction.webp
```
The placeholders in angle brackets are:
- `hash`: the SHA-1 sum of the image URL.
- `effect`: the image filters used to augment the original photo. An unchanged file has `origin` here.
- `group`: the dataset the image belongs to, one of `training`, `validation` or `testing`.
- `label`: the expected image class. See the details below.

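The naming scheme can be sketched in Python as a pair of helpers. `compose` and `decompose` are hypothetical names that are not part of the notebook. Hashing the URL as UTF-8 bytes is an assumption about how the SHA-1 sum was taken.

```python
import hashlib
import re

# <hash>.<effect>.<group>.<label>.webp; \w matches letters, digits and "_",
# so multi-filter effects such as "hue_cw_crop" stay in a single field
NAME = re.compile(r'^(\w+)\.(\w+)\.(\w+)\.(\w+)\.webp$')
FIELDS = ('hash', 'effect', 'group', 'label')


def compose(url: str, effect: str, group: str, label: str) -> str:
    """Build a file name; assumes the URL is hashed as UTF-8 bytes."""
    digest = hashlib.sha1(url.encode('utf-8')).hexdigest()
    return f'{digest}.{effect}.{group}.{label}.webp'


def decompose(name: str) -> dict:
    """Split a file name back into its four fields."""
    found = NAME.match(name)
    if found is None:
        raise ValueError(f'Unexpected file name: {name}')
    return dict(zip(FIELDS, found.groups()))


print(decompose(
    '0018559490dd9cb73caa00f078fde40b220803de'
    '.balance_down_hue_cw_rotate_ccw_crop.training.construction.webp'
))
```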
## Classes
- `luxury` is a flat with expensive furniture, huge rooms, chandeliers, fireplaces, etc.

![Luxury 1](../images/luxury1.webp)
![Luxury 2](../images/luxury2.webp)
![Luxury 3](../images/luxury3.webp)

- `comfort` is an apartment best suited for an ordinary citizen: clean, neat, sometimes minimalistic, of average area, with good-quality furniture, etc.

![Comfort 1](../images/comfort1.webp)
![Comfort 2](../images/comfort2.webp)
![Comfort 3](../images/comfort3.webp)

- `junk` is an image of an old flat. Most likely, the whole apartment belongs to a dormitory, a Khrushchevka or a gostinka.

![Junk 1](../images/junk1.webp)
![Junk 2](../images/junk2.webp)
![Junk 3](../images/junk3.webp)

- `construction` is a flat without finishing: no doors, flooring, utilities, wallpaper, ceiling, furniture, etc. Such apartments are typically found in new buildings.

![Construction 1](../images/construction1.webp)
![Construction 2](../images/construction2.webp)
![Construction 3](../images/construction3.webp)

- `excess` is the trash category: all exterior photos, outlines and posters belong here.

![Excess 1](../images/excess1.webp)
![Excess 2](../images/excess2.webp)
![Excess 3](../images/excess3.webp)

```python
%matplotlib inline
```

```python
# Plotting
from plotly.graph_objects import Pie, Figure, Scatter
from plotly.figure_factory import create_annotated_heatmap
from plotly.express import imshow
from plotly.subplots import make_subplots
# Data handling
from re import match
from glob import glob
from numpy import arange, trace, sum  # NB: shadows the built-in sum()
from pandas import DataFrame
from sklearn.metrics import confusion_matrix
# PyTorch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from torch.nn import Module, Conv2d, Linear, NLLLoss
from torch.nn.functional import log_softmax, relu, max_pool2d, dropout
from torch.optim import Adam
from torch import no_grad, save, max, load, zeros, cat, long, set_num_threads, get_num_threads  # NB: shadows max()
# Miscellaneous
from uuid import uuid4
from PIL.Image import open  # NB: shadows the built-in open()
from multiprocessing import cpu_count
from IPython.display import display
```

```python
# Let PyTorch use every available CPU core
set_num_threads(cpu_count())
print(f'Set thread number to {get_num_threads()}.')
```
```
Set thread number to 8.
```


```python
def draw(x):
    # Tensors are channel-first (C, H, W); plotly expects (H, W, C)
    imshow(x.permute(1, 2, 0)).show()
```

```python
# A tuple keeps the group order stable between runs (string set order is not)
groups = ('training', 'validation', 'testing')
labels = ['luxury', 'comfort', 'junk', 'construction', 'excess']
```

```python
def parse(path, mappings):
    # Expect .../<hash>.<effect>.<group>.<label>.webp
    result = match(r'^.*/\w+\.\w+\.(\w+)\.(\w+)\.webp$', path)
    if not result:
        raise RuntimeError(f'Got invalid path, {path}')
    group, label = result.groups()
    if group not in groups:
        raise RuntimeError(f'Got invalid group, {path}')
    if label not in mappings:
        raise RuntimeError(f'Got invalid label, {path}')
    # The label is returned as its integer index, as NLLLoss expects
    return path, group, mappings[label]
```

```python
def extract():
    # Map class names to indices and build a (path, group, label) table
    mappings = {l: i for i, l in enumerate(labels)}
    return DataFrame(
        map(lambda p: parse(p, mappings), glob('../scientific/images/*.webp')),
        columns=['path', 'group', 'label']
    )
```

```python
images = extract()
# One pie chart of class shares per dataset group
figure = make_subplots(
    cols=len(groups),
    specs=[[{'type': 'domain'}] * len(groups)],
    subplot_titles=list(groups)
)
for i, group in enumerate(groups, 1):
    counts = images[images['group'] == group]['label'].value_counts().sort_index()
    figure.add_trace(
        Pie(labels=[labels[j] for j in counts.index], values=counts.values, name=''),
        row=1,
        col=i
    )
figure.show()
```

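```python
class Gallery(Dataset):
    """Yields (normalized image tensor, label index, file path) triples."""

    def __init__(self, data):
        self._data = data.values
        self._transforms = Compose(
            [
                ToTensor(),
                Resize((460, 620)),
                # ImageNet channel means and standard deviations
                Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]
        )

    def __getitem__(self, index):
        path, _, label = self._data[index]
        # Force 3 channels (WebP may decode as RGBA) and close the file afterwards
        with open(path) as image:
            return self._transforms(image.convert('RGB')), label, path

    def __len__(self):
        return len(self._data)
```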

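`ToTensor()` scales 8-bit channels to [0, 1]. `Normalize(mean, std)` then computes (x − mean) / std per channel, here with the common ImageNet statistics. The sketch below works through that arithmetic for a single pixel and also shows its inverse, which recovers the original 0 to 255 values. It is written in plain Python for illustration only. `normalize_pixel` and `denormalize_pixel` are hypothetical helpers, not part of torchvision.

```python
MEAN = (0.485, 0.456, 0.406)  # ImageNet channel means, as passed to Normalize
STD = (0.229, 0.224, 0.225)   # ImageNet channel standard deviations


def normalize_pixel(rgb):
    """Mimic ToTensor() + Normalize() for one 8-bit (R, G, B) pixel."""
    return tuple((value / 255 - m) / s for value, m, s in zip(rgb, MEAN, STD))


def denormalize_pixel(channels):
    """Invert the transform: x * std + mean, then rescale to 0..255."""
    return tuple(round((x * s + m) * 255) for x, m, s in zip(channels, MEAN, STD))


white = normalize_pixel((255, 255, 255))
print([f'{x:.4f}' for x in white])
print(denormalize_pixel(white))
```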
```python
batch_size = 8
training_loader = DataLoader(Gallery(images[images['group'] == 'training']), batch_size, shuffle=True)
validation_loader = DataLoader(Gallery(images[images['group'] == 'validation']), batch_size)
testing_loader = DataLoader(Gallery(images[images['group'] == 'testing']), batch_size)
# Number of batches (not images) per loader
print(len(training_loader), len(validation_loader), len(testing_loader))
```
```
3944 492 764
```

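The printed numbers are batch counts, not image counts. With the default `drop_last=False`, `len(DataLoader)` equals ⌈N / batch_size⌉. Each count therefore only pins the split size to a window of eight values. The sketch below shows that arithmetic; the helper names are hypothetical.

```python
def batch_count(samples: int, batch_size: int) -> int:
    """Batches per epoch when the last, smaller batch is kept (drop_last=False)."""
    return (samples + batch_size - 1) // batch_size  # integer ceiling division


def sample_window(batches: int, batch_size: int) -> tuple:
    """Smallest and largest dataset sizes that give exactly `batches` batches."""
    return (batches - 1) * batch_size + 1, batches * batch_size


for name, batches in [('training', 3944), ('validation', 492), ('testing', 764)]:
    low, high = sample_window(batches, 8)
    print(f'{name}: {low}..{high} images')
```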

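```python
class Network(Module):
    def __init__(self):
        super().__init__()
        self._conv1 = Conv2d(3, 12, 11, 4)  # 3×460×620 → 12×113×153
        self._conv2 = Conv2d(12, 48, 5, 2)  # 12×113×153 → 48×55×75
        self._fc = Linear(47952, 5)         # 48×27×37 after pooling → 5 classes

    def forward(self, x):
        x = relu(self._conv1(x))
        x = relu(self._conv2(x))
        x = max_pool2d(x, 2)
        x = x.flatten(1)  # keep the batch dimension, flatten the rest
        # Pass self.training; otherwise functional dropout also fires in eval mode
        x = dropout(x, 0.4, self.training)
        x = self._fc(x)
        return log_softmax(x, 1)
```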

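The hard-coded 47952 in `Network` follows from the standard output-size rule for convolution and pooling layers, ⌊(n + 2p − k) / s⌋ + 1 (with dilation 1). The sketch below reproduces the number for the 460 × 620 input set by `Resize`; the helper names are hypothetical. It is worth rerunning whenever the input size or the layer parameters change.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of a convolution or pooling layer along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def flat_features(height: int = 460, width: int = 620) -> int:
    h, w = height, width
    h, w = conv_out(h, 11, 4), conv_out(w, 11, 4)  # conv1: 113 × 153
    h, w = conv_out(h, 5, 2), conv_out(w, 5, 2)    # conv2: 55 × 75
    h, w = conv_out(h, 2, 2), conv_out(w, 2, 2)    # max_pool2d(x, 2): stride defaults to 2
    return 48 * h * w                              # 48 output channels of conv2


print(flat_features())
```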
```python
def train():
    network = Network()
    criterion = NLLLoss()  # pairs with the log_softmax output of Network
    optimizer = Adam(network.parameters(), weight_decay=0.0001)
    epoch_number = 5
    epochs = arange(epoch_number)
    training_losses = [0.0] * epoch_number
    validation_losses = [0.0] * epoch_number
    for epoch in epochs:
        network.train()
        for batch in training_loader:
            optimizer.zero_grad()
            loss = criterion(network(batch[0]), batch[1])
            loss.backward()
            optimizer.step()
            # criterion returns the batch mean, so re-weight it by the batch size
            training_losses[epoch] += loss.item() * batch[0].size(0)
        training_losses[epoch] /= len(training_loader.sampler)
        network.eval()
        with no_grad():
            for batch in validation_loader:
                validation_losses[epoch] += criterion(network(batch[0]), batch[1]).item() * batch[0].size(0)
        validation_losses[epoch] /= len(validation_loader.sampler)
    # Keep a uniquely named snapshot and overwrite the "latest" one
    state = network.state_dict()
    save(state, f'../scientific/models/auge.{uuid4().hex}.pth')
    save(state, '../scientific/models/auge.latest.pth')
    # Plot the per-sample training and validation loss of each epoch
    figure = Figure()
    figure.add_trace(Scatter(x=epochs, y=training_losses, name='Training'))
    figure.add_trace(Scatter(x=epochs, y=validation_losses, name='Validation'))
    figure.show()
    return network
```

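`NLLLoss` averages over the batch by default. For that reason, `train` multiplies each batch loss by the batch size and divides the sum by the number of samples. The result is a true per-sample mean even when the final batch is smaller. Averaging the batch means directly would over-weight that last batch. The plain-Python sketch below uses made-up numbers; `epoch_mean` is a hypothetical helper.

```python
def epoch_mean(batch_means, batch_sizes):
    """Size-weighted mean of per-batch mean losses, as accumulated in train()."""
    total = sum(mean * size for mean, size in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)


# Made-up example: a full batch of 8 samples and a final batch of 2
per_sample = [1.0] * 8 + [4.0] * 2
batches = [per_sample[:8], per_sample[8:]]
means = [sum(b) / len(b) for b in batches]
sizes = [len(b) for b in batches]

print(epoch_mean(means, sizes))  # → 1.6, the per-sample mean
print(sum(means) / len(means))   # → 2.5, the naive mean of batch means
```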
```python
def use(tag='latest'):
    network = Network()
    network.load_state_dict(load(f'../scientific/models/auge.{tag}.pth'))
    return network
```

```python
def test(network):
    network.eval()
    predicted, actual = zeros(0, dtype=long), zeros(0, dtype=long)
    with no_grad():
        for batch in testing_loader:
            # The index of the highest log-probability is the predicted class
            predicted = cat([predicted, max(network(batch[0]), 1)[1].view(-1)])
            actual = cat([actual, batch[1].view(-1)])
    # Fix the class order so the matrix stays 5×5 even if a class never occurs
    matrix = confusion_matrix(actual, predicted, labels=list(range(len(labels))))
    figure = create_annotated_heatmap(z=matrix, x=labels, y=labels, hoverinfo='skip')
    figure.update_xaxes(title_text='Predicted')
    figure.update_yaxes(title_text='Actual', autorange='reversed')
    figure.update_layout(title=f'Accuracy: {trace(matrix) / sum(matrix) * 100:.2f}%')
    figure.show()
```

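The accuracy in the heatmap title is the trace of the confusion matrix divided by its total. Under scikit-learn's convention, rows are actual classes and columns are predicted ones. When classes differ in size, per-class recall (the diagonal entry over its row sum) can complement that single number. This is a suggestion; the notebook does not compute it. The plain-Python sketch below uses a made-up 3 × 3 matrix; `accuracy` and `recall` are hypothetical helpers.

```python
def accuracy(matrix):
    """Correct predictions (the diagonal) over all predictions."""
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total


def recall(matrix):
    """Per-class recall: diagonal entry over its row (actual-class) total."""
    return [row[i] / sum(row) if sum(row) else 0.0 for i, row in enumerate(matrix)]


# Made-up counts; rows are actual classes, columns are predicted classes
example = [
    [8, 1, 1],
    [2, 6, 2],
    [0, 0, 5],
]
print(f'Accuracy: {accuracy(example) * 100:.2f}%')
print([round(r, 3) for r in recall(example)])
```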
```python
def explain(network):
    network.eval()
    with no_grad():
        for batch in testing_loader:
            for result in zip(max(network(batch[0]), 1)[1], batch[1], batch[2]):
                if result[0] != result[1]:
                    print(f'Predicted {labels[result[0]]}, actual {labels[result[1]]}')
                    display(open(result[2]))
```

```python
%%time
for _ in range(1):
    test(train())
```
```
CPU times: user 6h 10min 9s, sys: 16.2 s, total: 6h 10min 26s
Wall time: 47min 43s
```


```python
%%time
test(use())
```
```
CPU times: user 10min 19s, sys: 537 ms, total: 10min 20s
Wall time: 1min 21s
```


```python
explain(use())
```