# Lithology prediction

First, the necessary imports.

In [None]:
import os
import sys
import numpy as np
import pandas as pd
import glob
import matplotlib.pyplot as plt
import shutil
import dill
import datetime

sys.path.insert(0, os.path.join("..", "..", ".."))

from petroflow import WellDataset, WS
from petroflow.batchflow.models.torch import UNet
from petroflow.batchflow import DatasetIndex, FilesIndex, ImagesBatch, Pipeline, V, B, W, L, R, P

from utils import *

Define constants used at both the training and inference stages.

In [None]:
# Number of wells per training batch and number of random crops taken from each well.
BATCH_SIZE = 16
N_CROPS = 4

# Number of crops per batch after rebatching (BATCH_SIZE wells x N_CROPS crops each).
CROPS_BATCH = BATCH_SIZE * N_CROPS

N_EPOCH = 500
# Crop length. NOTE(review): presumably metres of core; 2500 below would then be
# pixels per metre at the 25 px/cm used by `load_core` — confirm.
LENGTH = 0.1
# Model input shape; the batch is later transposed to channels-first, so SHAPE[0]
# is the channel count. `round` instead of `int` so that a product such as
# 2500 * LENGTH landing just below an integer (floating-point error) is not
# truncated one pixel short.
SHAPE = (3, round(2500 * LENGTH), 250)

# Filter counts for the UNet stages: [4, 8, 16, 32].
FILTERS = ((2 ** np.arange(4)) * 4).tolist()

PATH = '/notebooks/data/september_dataset/core_photo/*/*'

Remove wells that have no `core_lithology` data.

In [None]:
# Index every well directory matching PATH and keep only wells with core lithology.
ds = filter_dataset(WellDataset(index=FilesIndex(path=PATH, dirs=True)))

Get the list of classes and construct the mappings between class names and integer labels.

In [None]:
# Class names found in the dataset and the lookup tables between them and integer
# labels: `reverse_mapping` is label -> class, `mapping` is class -> label.
classes = get_classes(ds)
reverse_mapping = {label: name for label, name in enumerate(classes)}
mapping = {name: label for label, name in reverse_mapping.items()}

Preload all the data to make training process faster.

In [None]:
%%time
# Preload all wells into memory once: build segments, load core images and fetch
# the whole dataset as a single batch.

load_ppl = (ds.p
    # NOTE(review): presumably splits wells into connected segments by `samples`
    # and then by `core_lithology` — confirm against petroflow docs.
    .create_segments(src='samples', connected=True)
    .create_segments(src='core_lithology', connected=True)
    # Presumably drops segments shorter than LENGTH so that crops of LENGTH fit.
    .drop_short_segments(LENGTH)
    # Core photos are loaded at 25 pixels per cm.
    .load_core(pixels_per_cm=25)
    # A single batch containing every well of the dataset.
    .next_batch(len(ds), n_epochs=1)
)

# Rebuild the dataset from the preloaded batch (see `build_dataset` in utils),
# then split it into `ds.train` / `ds.test`.
ds = build_dataset(load_ppl)
# NOTE(review): 42 looks like a seed for the shuffled split — confirm.
ds.split(shuffle=42)

## Train model

In [None]:
# Training preprocessing: random crops -> class labels -> masks -> numpy arrays.
crop_template = (Pipeline()
    # Presumably lets numpy functions (`array`, `reshape`) be used as actions.
    .add_namespace(np)
    .copy()
    .random_crop(length=LENGTH, n_crops=N_CROPS)
    # Class of each lithology interval: FORMATION and GRAIN joined via `concat`.
    .update(WS('core_lithology')['CLASS'],
            WS('core_lithology')[['FORMATION', 'GRAIN']].apply(concat, axis=1).ravel())
    .create_mask(src='core_lithology', column='CLASS', mapping=mapping, mode='core')
    # Gather per-well core images and masks into batch components.
    .update(B('core'), WS('core_dl').ravel())
    .update(B('masks'), WS('mask').ravel())
    .array(B('core'), save_to=B('core'))
    .array(B('masks'), save_to=B('masks'))
    # One mask row per crop; its length must match the model's
    # 'inputs/masks/shape' (len(mapping), 1, SHAPE[1]), so derive it from SHAPE
    # instead of hard-coding it.
    .reshape(B('masks'), (-1, 1, SHAPE[1]), save_to=B('masks'))
)

In [None]:
# Augmentation applied to the crops produced by `crop_template`.
augmentation_template = (
    Pipeline()
    # Re-index the batch with one index item per element of B('core'), presumably
    # so that `rebatch` regroups individual crops rather than wells.
    .update(B().index, L(DatasetIndex)(L(len)(B('core'))))
    # Regroup into ImagesBatch batches of CROPS_BATCH items with `core` and `masks`.
    .rebatch(CROPS_BATCH, batch_class=ImagesBatch, components=('core', 'masks'))
    .add_namespace(np)
    .to_pil(src='core', dst='core')
    # Random scale, factor drawn uniformly between 1 and 1.5; output keeps its shape.
    .scale(src='core', dst='core', preserve_shape=True, factor=P(R('uniform', low=1, high=1.5)))
    # With p=0.5, fill a random rectangle with color 0. Its size is drawn between
    # [200, 0] and [250, 40]; its origin is drawn from uniform(size=2) (presumably
    # relative coordinates — confirm).
    .cutout(shape=P(R('randint', low=[200, 0], high=[250, 40])),
            origin=P(R('uniform', size=2)), color=0,
            src='core', dst='core', p=0.5)
    # Multiply pixel values by a random factor between 0.7 and 1.2.
    .multiply(src='core', dst='core', multiplier=P(R('uniform', low=0.7, high=1.2)))
    .to_array(src='core', dst='core', dtype='float32')
    # Move the last axis to position 1; presumably (N, H, W, C) -> (N, C, H, W).
    .transpose(B('core'), axes=(0, 3, 1, 2), save_to=B('core'))
)

In [None]:
# UNet configuration, keyed by slash-separated config paths.
model_config = {
    # Encoder: len(FILTERS[:-1]) stages with filters FILTERS[:-1];
    # decoder mirrors them in reverse order (FILTERS[-2::-1]).
    "body/encoder/num_stages": len(FILTERS[:-1]),
    'body/encoder/blocks/filters': FILTERS[:-1],
    "body/decoder/blocks/filters": FILTERS[-2::-1],
    "initial_block/inputs": "inputs",
    "inputs/inputs/shape": SHAPE,
    # One row of length SHAPE[1] per class.
    'inputs/masks/shape': (len(mapping), 1, SHAPE[1]),
    # Single conv head with a (SHAPE[2], 1) kernel and 'valid' padding.
    # NOTE(review): this kernel collapses the axis of size SHAPE[1] while the
    # masks keep length SHAPE[1]; the shapes only agree because
    # SHAPE[1] == SHAPE[2] == 250. Confirm the axis order before changing LENGTH.
    "head": dict(layout="c",
                 kernel_size=(SHAPE[2], 1), padding='valid', conv=dict(bias=True)),
    "loss": "ce",
    "optimizer": {"name": "Adam", "lr": 0.01},
    # Exposes predicted probabilities as the 'proba' fetch used at inference.
    "output": 'proba',
    'device': 'gpu:1',
}

# Model training: the loss of every iteration is appended to V('loss_history').
train_template = (Pipeline()
    .init_variable('loss_history', default=[])
    .init_model('dynamic', UNet, 'model', model_config)
    .train_model('model', B('core').astype('float32'), B('masks'),
                 fetches='loss', save_to=V('loss_history', mode='a'))
)

# Full training pipeline: crop -> augment -> train, bound to the training subset.
train_ppl = (crop_template + augmentation_template + train_template) << ds.train

In [None]:
# Train for N_EPOCH epochs over ds.train with BATCH_SIZE wells per batch; the
# progress bar description shows the most recent loss value.
train_ppl.run(BATCH_SIZE, n_epochs=N_EPOCH, bar=True, bar_desc=W(V('loss_history')[-1]))

Dump results

In [None]:
# Timestamped output directory, e.g. './models/unet_2019-10-01_12:30:45.123456'
# (str(datetime) is isoformat with ' ' as separator; here the separator is '_').
SAVE_TO = './models/unet_{}'.format(datetime.datetime.now().isoformat(sep='_'))
dump_results(train_ppl, SAVE_TO)

In [None]:
# Saved model directory matching './models/unet_*' (presumably the most recent
# one — see `get_last_model_path` in utils). The loading cells below read from it.
model_path = get_last_model_path('./models/unet_*')
print(model_path)

Load the training loss history.

In [None]:
# Read the pickled loss history from the model directory.
# NOTE: dill.load can execute arbitrary code — only load files you produced yourself.
loss_file = os.path.join(model_path, 'loss.pkl')
with open(loss_file, mode='rb') as loss_fh:
    loss = dill.load(loss_fh)

In [None]:
# Raw training loss and its 100-point moving average (NaN for the first 99 points).
# Assumes `loss` is the per-iteration history collected in V('loss_history').
loss_series = pd.Series(loss)
plt.plot(loss_series, label='loss')
plt.plot(loss_series.rolling(window=100).mean(), label='rolling mean (window=100)')
plt.legend()
plt.show()

## Inference

In [None]:
def test_template(length, random_crop=False, step=None, n_crops=None,
                  load_path='unet.torch', batch_size=32, device='gpu:1'):
    """Build an inference pipeline that crops wells, predicts masks and gathers metrics.

    Parameters
    ----------
    length : float
        Crop length, in the same units as `LENGTH`. The loaded model was trained on
        crops of `LENGTH` (masks are reshaped to SHAPE[1]), so this should match it.
    random_crop : bool
        If True, take `n_crops` random crops per well; otherwise crop with `crop`
        using `length` and `step`.
    step : float, optional
        Step passed to `crop`; a falsy value falls back to `length`.
    n_crops : int, optional
        Number of random crops per well; a falsy value falls back to 4.
        Only used when `random_crop` is True.
    load_path : str
        Value of the model's 'load/path' config, i.e. the weights to load.
        NOTE(review): confirm this points at the checkpoint written by `dump_results`.
    batch_size : int
        Number of crops per prediction batch after rebatching.
    device : str
        Device the model is loaded on.

    Returns
    -------
    Pipeline
        A template; bind it to a dataset with `<<` before running. Classification
        metrics are accumulated in the pipeline variable 'metrics'.
    """
    step = step or length
    n_crops = n_crops or 4

    # Copy before cropping, the same order as `crop_template`, so cropping works on
    # copies (presumably leaving the preloaded wells untouched).
    ppl = Pipeline().copy()
    if random_crop:
        ppl = ppl.random_crop(length=length, n_crops=n_crops)
    else:
        ppl = ppl.crop(length=length, step=step)

    ppl = ppl + (Pipeline()
        .add_namespace(np)
        .add_components(('core', 'masks'))
        # Same class labelling and mask creation as in `crop_template`.
        .update(WS('core_lithology')['CLASS'],
                WS('core_lithology')[['FORMATION', 'GRAIN']].apply(concat, axis=1).ravel())
        .create_mask(src='core_lithology', column='CLASS', mapping=mapping, mode='core')
        .update(B('core'), WS('core_dl').ravel())
        .update(B('masks'), WS('mask').ravel())
        .array(B('core'), save_to=B('core'))
        .array(B('masks'), save_to=B('masks'))
        # Presumably (N, H, W, C) -> (N, C, H, W), as in `augmentation_template`.
        .transpose(B('core'), axes=(0, 3, 1, 2), save_to=B('core'))
        # Mask length must match the model's 'inputs/masks/shape'.
        .reshape(B('masks'), (-1, 1, SHAPE[1]), save_to=B('masks'))
        # Re-index by crop, then regroup crops into fixed-size prediction batches.
        .update(B().index, L(DatasetIndex)(B('core').shape[0]))
        .rebatch(batch_size, components=('core', 'masks'), batch_class=ImagesBatch)
        .init_variable('metrics')
        .init_model('dynamic', UNet, 'model', config={
                        'device': device, 'load/path': load_path
                    })
        .predict_model('model', B('core').astype('float32'), fetches='proba', save_to=B('proba'))
        # Compare per-pixel argmax predictions with the flattened masks.
        .gather_metrics('classification', targets=B('masks').reshape(-1),
                        predictions=B('proba').argmax(1).reshape(-1),
                        fmt='labels', num_classes=len(mapping), save_to=V('metrics', mode='u'))
    )
    return ppl

test_ppl = test_template(LENGTH, random_crop=False) << ds.test

In [None]:
# Run inference once over the test subset (10 wells per batch); metrics accumulate
# in V('metrics').
test_ppl.run(10, n_epochs=1, bar=True)

# Store the metrics in `model_path`, the directory the next cell reads them from.
# Using it instead of SAVE_TO keeps writer and reader in sync, and also works when
# the training cells were skipped and SAVE_TO is undefined.
dump_metrics(test_ppl, os.path.join(model_path, 'metrics.pkl'))

In [None]:
# Read the pickled test metrics from the model directory.
# NOTE: dill.load can execute arbitrary code — only load files you produced yourself.
metrics_file = os.path.join(model_path, 'metrics.pkl')
with open(metrics_file, mode='rb') as metrics_fh:
    metrics = dill.load(metrics_fh)

F1 scores for each class.

In [None]:
# One F1 score per class label (multiclass=None), printed next to its class name.
f1_scores = metrics.evaluate('f1_score', agg='mean', multiclass=None)
for label, score in enumerate(f1_scores):
    print(reverse_mapping[label], score)

Visualize some examples from the test set.

In [None]:
# Take one batch (batch size 1) from the test subset, using random crops. The
# pipeline also runs the model, so the batch carries its predictions.
batch = (test_template(LENGTH, random_crop=True) << ds.test).next_batch(1)

In [None]:
# Plot the examples from `batch`; `reverse_mapping` presumably turns integer labels
# back into class names (see `plot_examples` in utils).
plot_examples(batch, reverse_mapping)