# Lithology prediction from core images

In [1]:
import os
import sys
import numpy as np
import pandas as pd
import pickle
import collections
import matplotlib.pyplot as plt
import torch
import time

# Expose only physical GPU 1 to this process. CUDA renumbers visible devices
# from zero, so inside the process that GPU is addressed as device 0.
os.environ["CUDA_VISIBLE_DEVICES"]="1"  # looks like a leftover of a longer device list: ,1,2,4"

# Put the directory three levels up on the import path (presumably the
# repository root) so the `petroflow` imports below resolve.
sys.path.insert(0, os.path.join("..", "..", ".."))

from petroflow.batchflow.models.torch import UNet, ResNet18, ResNet34
from petroflow.batchflow import Dataset, FilesIndex, Pipeline, V, B, inbatch_parallel, I, W, F, ImagesBatch

In [2]:
# import shutil

# for item in list(os.walk('/notebooks/data/august_dataset/crops'))[0][1]:
#     if item not in ['Песчаник', 'Алевролит', 'Уголь', 'Аргиллит']:
#         shutil.move('/notebooks/data/august_dataset/crops/'+item, '/notebooks/data/august_dataset/other_crops/'+item)

## Dataset

In [3]:
# Index every file matching `<subfolder>/<name>_dl.png` under the crops folder.
index = FilesIndex(path='/notebooks/data/august_dataset/crops/*/*_dl.png')
ds = Dataset(index=index, batch_class=ImagesBatch)
# Creates the ds.train / ds.test parts used by later cells; 0.8 is presumably
# the train share and 42 the shuffle seed — TODO confirm against batchflow docs.
ds.split(0.8, shuffle=42)

# Annotation table; its NAME and LITHOLOGY columns are used by later cells.
annotation = pd.read_feather('/notebooks/data/august_dataset/crops/annotation.feather')

In [4]:
# Append the '_dl.png' suffix carried by the indexed crop files to NAME and use
# it as the table index (presumably so labels can be looked up by file name in
# the loading pipeline).
# The guard keeps the cell safe to re-run: after the first run NAME is the
# index rather than a column, and the suffix must not be appended twice.
if 'NAME' in annotation.columns:
    annotation = (
        annotation
        .assign(NAME=annotation['NAME'] + '_dl.png')
        .set_index('NAME')
    )

In [5]:
# Loading stage shared by the training, test and example pipelines below.
load_ppl = (
    Pipeline()
    # Read the indexed image files into the `images` component.
    .load(fmt='image', dst='images')
    # Fill `labels` from the LITHOLOGY column of the annotation table
    # (presumably matched to batch items by file name — TODO confirm).
    .load(src=annotation.LITHOLOGY, dst='labels')
    # Replace `labels` with its `.values` attribute (a plain array if the
    # loaded labels are a pandas Series, as they appear to be).
    .update(B('labels'), B('labels').values)
)

In [6]:
# Helper stage that collects batch labels in the `lithology` pipeline variable.
# In this notebook it is only referenced from the commented-out class-counting cell.
counter_ppl = (
    Pipeline()
    .init_variable('lithology', default=[])
    # mode='e' presumably extends the list with each batch's labels — TODO confirm.
    .update(V('lithology', mode='e'), B('labels'))
)

In [7]:
# import collections

# ppl = (load_ppl + counter_ppl << ds)
# (ppl
#  .after
#  .add_namespace(collections)
#  .init_variable('counter')
#  .Counter(V('lithology'), save_to=V('counter'))
# )

# ppl.run(10, bar=True, n_epochs=1)
# ppl.v('counter')

In [8]:
# labels_mapping = {i: k for k, i in enumerate(ppl.v('counter'))}

# with open('resnet/labels_mapping', 'wb') as f:
#     pickle.dump(labels_mapping, f)

# Restore the class-name -> class-index mapping stored by an earlier run.
# NOTE: pickle can execute arbitrary code on load; only open trusted files.
with open('resnet/labels_mapping', 'rb') as mapping_file:
    labels_mapping = pickle.load(mapping_file)

# Inverse lookup (class index -> class name), used when displaying predictions.
reverse_mapping = dict(zip(labels_mapping.values(), labels_mapping.keys()))
labels_mapping

{'Алевролит': 2, 'Аргиллит': 1, 'Песчаник': 3, 'Уголь': 0}

In [9]:
def encode(labels, mapping):
    """Translate every label into its code from `mapping`.

    Parameters
    ----------
    labels : iterable
        Labels to translate; each item must be a key of `mapping`.
    mapping : dict
        Lookup table from label to code.

    Returns
    -------
    np.ndarray
        The codes, in the same order as `labels`.

    Raises
    ------
    KeyError
        If a label is not present in `mapping`.
    """
    codes = []
    for label in labels:
        codes.append(mapping[label])
    return np.array(codes)

# NOTE(review): BATCH_SIZE and N_EPOCH are not referenced by the run() calls in
# this notebook, which pass literal values — confirm which settings are intended.
BATCH_SIZE = 8
N_EPOCH = 50
# Channels-first model input shape; presumably (channels, height, width).
SHAPE = (3, 500, 250)

# Configuration for the training model. The 'initial_block/inputs' key used to
# be listed twice with the same value; it is now given once (the resulting
# dict is unchanged).
model_config = {
    'initial_block/inputs': 'images',
    'inputs/images/shape': SHAPE,
    # One output class per entry of the label mapping.
    'inputs/labels/classes': len(labels_mapping),
    'optimizer': 'Adam',
    'output': 'proba',
    'device': 'gpu:0',
    'loss': 'ce',
}

# Training stage: crop, convert to channels-first float arrays, encode labels
# and run one training step per batch, recording the loss.
train_tmp = (Pipeline()
    .add_namespace(np)
    # Random crop; the crop shape is taken from SHAPE as (SHAPE[2], SHAPE[1]).
    .crop(src='images', dst='images', origin='random', shape=(SHAPE[2], SHAPE[1]))
    .to_array(src='images', dst='images', dtype='float32')
    .init_variable('loss', default=[])
    # Presumably np.transpose from the added namespace: moves the last axis to position 1.
    .transpose(B('images'), axes=(0, 3, 1, 2), save_to=B('images'))
    # Presumably resolves to the notebook-level `encode` helper: label names -> class indices.
    .encode(B('labels'), labels_mapping, save_to=B('labels'))
    .init_model('dynamic', ResNet18, 'model', model_config)
    # Fetch the loss of each step and store it in the `loss` variable (mode='a', presumably append).
    .train_model('model', B('images'), B('labels'), use_lock=True, fetches='loss',
             save_to=V('loss', mode='a'))
)

In [None]:
# Bind the loading + training stages to the train split and run training.
# NOTE(review): the batch size (16) and n_epochs (100) are literals here and
# differ from BATCH_SIZE / N_EPOCH defined earlier — confirm which are intended.
train_ppl = (load_ppl + train_tmp << ds.train)
train_ppl.run(16, n_epochs=100, shuffle=42, bar=True, prefetch=3)#, bar_desc=W(V('loss')[-1]))

  2%|▏         | 541/27294 [02:42<2:39:21,  2.80it/s]

In [None]:
# Per-iteration training loss together with its 100-step rolling mean.
loss_history = np.array(train_ppl.v('loss'))
plt.plot(loss_history)
plt.plot(pd.Series(loss_history).rolling(100).mean())

In [None]:
# These imports repeat the ones at the top of the notebook, presumably so this
# cell and the plotting cell after it can run on a fresh kernel without retraining.
import matplotlib.pyplot as plt
import pandas as pd
import pickle
import numpy as np

# Load a loss history stored by an earlier run; the 'resnet/loss' file is
# written by the saving cell further down in this notebook.
# NOTE: pickle can execute arbitrary code on load; only open trusted files.
with open('resnet/loss', 'rb') as f:
    loss = pickle.load(f)

In [None]:
# Saved training loss and its 100-step rolling mean.
smoothed_loss = pd.Series(np.array(loss)).rolling(100).mean()
plt.plot(loss)
plt.plot(smoothed_loss)

In [None]:
# Persist the trained model.
trained_model = train_ppl.get_model_by_name('model')
trained_model.save('resnet/model.torch')

# Pickle the recorded loss history.
with open('resnet/loss', 'wb') as loss_file:
    pickle.dump(train_ppl.get_variable('loss'), loss_file)

# Pickle the dataset object.
with open('resnet/dataset', 'wb') as dataset_file:
    pickle.dump(ds, dataset_file)

In [None]:
# Evaluation stage: same preprocessing as training, then predict with the saved
# model and accumulate classification metrics in the `metrics` variable.
test_tmp = (Pipeline()
    .add_namespace(np)
    # NOTE(review): the crop origin is random, so evaluation results can vary between runs.
    .crop(src='images', dst='images', origin='random', shape=(SHAPE[2], SHAPE[1]))
    .to_array(src='images', dst='images', dtype='float32')
    .init_variable('metrics', default=None)
    # Move the last axis (channels) to position 1, as in the training stage.
    .transpose(B('images'), axes=(0, 3, 1, 2), save_to=B('images'))
    .encode(B('labels'), labels_mapping, save_to=B('labels'))
    .init_model('dynamic', ResNet18, 'model', config={
        # CUDA_VISIBLE_DEVICES exposes a single GPU (set at the top of the
        # notebook), so the only valid device index is 0 — the same device
        # the training config uses. 'gpu:1' does not exist in this process.
        'device': 'gpu:0',
        'load/path': 'resnet/model.torch',
    })
    .predict_model('model', B('images'), fetches='proba', save_to=B('proba'))
    .gather_metrics('class', targets=B('labels'), predictions=B('proba'),
                    fmt='proba', axis=-1, save_to=V('metrics', mode='u'))
)

In [None]:
# `+` binds tighter than `<<`: the stages are concatenated first, then the
# combined pipeline is bound to the test split and run for a single epoch.
test_ppl = (load_ppl + test_tmp) << ds.test
test_ppl.run(64, n_epochs=1, bar=True)

In [None]:
# Confusion matrix and summary metrics accumulated over the test run.
val_metrics = test_ppl.get_variable('metrics')
print(val_metrics._confusion_matrix)

for metric_name in ('specificity', 'sensitivity', 'accuracy', 'f1_score'):
    metric_value = val_metrics.evaluate(metric_name)
    print(metric_name, ':', metric_value)

In [None]:
# Fresh pipeline over the test split; pull one shuffled batch of 64 items
# (with predictions in `proba`) for visual inspection.
example_ppl = (load_ppl + test_tmp) << ds.test
b = example_ppl.next_batch(64, shuffle=True)

In [None]:
# Show every image of the example batch titled "<target>     <prediction>",
# in green when the two agree and in red otherwise.
for i in range(len(b.images)):
    # The pipeline stores images channels-first (C, H, W); imshow expects
    # (H, W, C), which axes (1, 2, 0) restore. Axes (2, 1, 0) would also swap
    # height and width and draw the crop transposed.
    image = b.images[i].transpose((1, 2, 0))
    target = reverse_mapping[b.labels[i]]
    pred = reverse_mapping[b.proba[i].argmax()]

    plt.figure(figsize=(5, 10))
    # Pixel values are float32 in the 0..255 range (presumably); scale to 0..1 for imshow.
    plt.imshow(image / 255)
    plt.title(target + '     ' + pred, color='g' if target == pred else 'r')
    plt.show()