In [1]:
import sys

# Make the project's local packages importable from this notebook (presumably
# where `dataflow` lives). Guarded so re-running the cell does not keep
# prepending duplicate entries to sys.path.
CODE_PATH = "../code"
if CODE_PATH not in sys.path:
    sys.path.insert(0, CODE_PATH)

In [2]:
import tqdm
import torch
from torch.utils.data.sampler import WeightedRandomSampler
from torch.utils.data import DataLoader, Sampler
from torch.utils.data.dataset import Dataset, Subset
import torch.utils.data.distributed as data_dist

from dataflow.transforms import TransformedDataset
from dataflow.datasets import get_trainval_datasets, read_img_in_db_with_mask

In [3]:
# Root of the input data, relative to the notebook's working directory;
# the training tiles are read from a subfolder of it.
INPUT_PATH = "../input/"

In [4]:
import os

# Training tiles directory and the accompanying `tile_stats.csv`.
data_path = os.path.join(INPUT_PATH, "train_tiles")
csv_path = os.path.join(data_path, "tile_stats.csv")

# Fold split: folds 0, 1 and 3 are used for training, fold 2 for validation.
train_folds = [0, 1, 3]
val_folds = [2]
# A fold present in both lists would leak validation tiles into training.
assert not set(train_folds) & set(val_folds), "train and val folds overlap"

train_ds, val_ds = get_trainval_datasets(
    data_path,
    csv_path,
    train_folds=train_folds,
    val_folds=val_folds,
    read_img_mask_fn=read_img_in_db_with_mask,
)


In [5]:
from ignite.engine import Engine
from ignite.metrics import VariableAccumulation, Average
from ignite.contrib.handlers import ProgressBar


class _Average(Average):
    """``Average`` whose accumulation op sums batched outputs over dim 0.

    A tensor output with more than one dimension, e.g. shape ``(batch, ...)``,
    is summed over its first dimension before being added to the running
    accumulator. Scalars, 1-D tensors and plain Python numbers are added
    unchanged.

    NOTE(review): presumably this works around an ignite release whose
    ``Average`` op did not reduce over the batch dimension; confirm against
    the installed ignite version, since newer releases may make it redundant.
    """

    def __init__(self, output_transform=lambda x: x, device=None):
        super(_Average, self).__init__(output_transform=output_transform, device=device)

        def _mean_op(a, x):
            # Only tensors have `ndim`; a plain number output would otherwise
            # raise AttributeError here, so it is added as-is.
            if isinstance(x, torch.Tensor) and x.ndim > 1:
                x = x.sum(dim=0)
            return a + x

        # Replace the op installed by the parent constructor.
        self._op = _mean_op


In [6]:
from albumentations.pytorch import ToTensorV2


# Apply albumentations' ToTensorV2 to every training sample.
# NOTE(review): this rebinds `train_ds`, so re-running this cell without first
# re-running the dataset cell wraps the dataset a second time.
train_ds = TransformedDataset(train_ds, transform_fn=ToTensorV2())

# Ordered pass over every sample: no shuffling, and the last partial batch is kept.
train_loader = DataLoader(
    train_ds,
    batch_size=16,
    shuffle=False,
    drop_last=False,
    num_workers=10,
    pin_memory=False,
)


In [9]:
def compute_mean_std(engine, batch):
    """Per-image, per-channel first and second moments of a batch.

    Args:
        engine: the ignite ``Engine`` running this step (unused here).
        batch: either a mapping holding the images under ``'image'`` (e.g. a
            collated dict batch) or the image tensor itself. The tensor is
            read as ``(batch, channels, *rest)``; all trailing dims are
            flattened together.

    Returns:
        dict with ``"mean"`` (E[x]) and ``"mean^2"`` (E[x^2]), each a float64
        tensor of shape ``(batch, channels)``.
    """
    # Accept both batch forms so one definition serves the real loader and raw tensors.
    images = batch if isinstance(batch, torch.Tensor) else batch['image']
    b, c, *_ = images.shape
    # Cast before squaring so squares and sums are computed in float64.
    data = images.reshape(b, c, -1).to(dtype=torch.float64)

    mean = torch.mean(data, dim=-1)
    mean2 = torch.mean(data ** 2, dim=-1)

    return {
        "mean": mean,
        "mean^2": mean2,
    }


# Engine whose step function computes per-image moments for every batch.
compute_engine = Engine(compute_mean_std)
ProgressBar(desc="Compute Mean/Std").attach(compute_engine)

# Each metric pulls one moment out of the step output and accumulates it over
# the run; the attach name is the key it gets in `state.metrics`.
img_mean = _Average(output_transform=lambda output: output['mean'])
img_mean.attach(compute_engine, 'mean')

img_mean2 = _Average(output_transform=lambda output: output['mean^2'])
img_mean2.attach(compute_engine, 'mean2')

In [10]:
# Run the statistics engine over the training loader; the averaged moments end up in `state.metrics`.
state = compute_engine.run(train_loader)

HBox(children=(IntProgress(value=0, description='Compute Mean/Std', max=390, style=ProgressStyle(description_w…

In [11]:
# Per-channel E[x] ('mean') and E[x^2] ('mean2').
state.metrics

{'mean': tensor([-17.7050, -10.3331, -12.4229], dtype=torch.float64),
 'mean2': tensor([356.2868, 143.1767, 191.1010], dtype=torch.float64)}

In [12]:
# Per-channel variance: E[x^2] - E[x]^2.
state.metrics['mean2'] - state.metrics['mean'] ** 2

tensor([42.8202, 36.4036, 36.7713], dtype=torch.float64)

In [13]:
# Per-channel std (ddof=0): sqrt(E[x^2] - E[x]^2). The variance is clamped at 0
# because float cancellation can make it slightly negative for near-constant
# channels, which sqrt would turn into NaN.
state.metrics['std'] = torch.sqrt(torch.clamp(state.metrics['mean2'] - state.metrics['mean'] ** 2, min=0.0))

In [14]:
# Per-channel standard deviation.
state.metrics['std']

tensor([6.5437, 6.0335, 6.0639], dtype=torch.float64)

Test: check the batched mean/std computation against NumPy on synthetic data

In [15]:
import numpy as np

# Synthetic data: n batches of b images, each with c channels of h x w pixels.
n = 8
b = 12
c = 3
w = h = 64

# Consecutive float64 values shifted down by 3/4 of the element count, so the
# data spans negative and positive values.
true_data = (
    np.arange(n * b * c * h * w, dtype=np.float64).reshape(n * b, c, h, w)
    - 0.75 * (n * b * c * h * w)
)
true_data.shape

(96, 3, 64, 64)

In [16]:
# Reference statistics computed directly with NumPy: put channels last and
# flatten everything else, so each column holds all values of one channel.
per_channel = true_data.transpose(0, 2, 3, 1).reshape(-1, c)
mean = per_channel.mean(axis=0)
mean2 = (per_channel ** 2).mean(axis=0)
std = per_channel.std(axis=0)
del per_channel  # large copy; not needed past this cell
mean, mean2, std

(array([-299008.5, -294912.5, -290816.5]),
 array([2.05359015e+11, 2.02926315e+11, 2.00527169e+11]),
 array([340518.62237153, 340518.62237153, 340518.62237153]))

In [17]:
# train_loader = torch.arange(0, n * b * h * w * c, dtype=torch.float64).reshape(n, b, c, h, w) - (n * b * c * w * h * 0.75)

In [18]:
# NOTE(review): this rebinds `train_loader` (the real DataLoader built earlier)
# to a tensor of synthetic batches; the real loader is lost until its cell is
# re-run. Iterating this tensor yields n slices of shape (b, c, h, w).
train_loader = torch.from_numpy(true_data).view(n, b, c, h, w)

In [19]:
# Expected (n, b, c, h, w).
train_loader.shape

torch.Size([8, 12, 3, 64, 64])

In [20]:
class _Average(Average):
    """``Average`` whose accumulation op sums batched outputs over dim 0.

    A tensor output with more than one dimension, e.g. shape ``(batch, ...)``,
    is summed over its first dimension before being added to the running
    accumulator. Scalars, 1-D tensors and plain Python numbers are added
    unchanged.

    NOTE(review): presumably this works around an ignite release whose
    ``Average`` op did not reduce over the batch dimension; confirm against
    the installed ignite version, since newer releases may make it redundant.
    """

    def __init__(self, output_transform=lambda x: x, device=None):
        super(_Average, self).__init__(output_transform=output_transform, device=device)

        def _mean_op(a, x):
            # Only tensors have `ndim`; a plain number output would otherwise
            # raise AttributeError here, so it is added as-is.
            if isinstance(x, torch.Tensor) and x.ndim > 1:
                x = x.sum(dim=0)
            return a + x

        # Replace the op installed by the parent constructor.
        self._op = _mean_op


In [21]:
def compute_mean_std(engine, batch):
    """Per-image, per-channel first and second moments of a batch.

    Args:
        engine: the ignite ``Engine`` running this step (unused here).
        batch: either a mapping holding the images under ``'image'`` (e.g. a
            collated dict batch) or the image tensor itself. The tensor is
            read as ``(batch, channels, *rest)``; all trailing dims are
            flattened together.

    Returns:
        dict with ``"mean"`` (E[x]) and ``"mean^2"`` (E[x^2]), each a float64
        tensor of shape ``(batch, channels)``.
    """
    # Accept both batch forms so one definition serves the real loader and raw tensors.
    images = batch if isinstance(batch, torch.Tensor) else batch['image']
    b, c, *_ = images.shape
    # Cast before squaring so squares and sums are computed in float64.
    data = images.reshape(b, c, -1).to(dtype=torch.float64)

    mean = torch.mean(data, dim=-1)
    mean2 = torch.mean(data ** 2, dim=-1)

    return {
        "mean": mean,
        "mean^2": mean2,
    }


# Engine whose step function computes per-image moments for every batch.
compute_engine = Engine(compute_mean_std)
ProgressBar(desc="Compute Mean/Std").attach(compute_engine)

# One running average per moment: each metric pulls its entry out of the step
# output dict and is reported in `state.metrics` under its attach name.
img_mean = _Average(output_transform=lambda output: output['mean'])
img_mean.attach(compute_engine, 'mean')

img_mean2 = _Average(output_transform=lambda output: output['mean^2'])
img_mean2.attach(compute_engine, 'mean2')

In [22]:
# Sanity check on the first synthetic batch: one row of moments per image.
compute_mean_std(None, train_loader[0, ...])

{'mean': tensor([[-882688.5000, -878592.5000, -874496.5000],
         [-870400.5000, -866304.5000, -862208.5000],
         [-858112.5000, -854016.5000, -849920.5000],
         [-845824.5000, -841728.5000, -837632.5000],
         [-833536.5000, -829440.5000, -825344.5000],
         [-821248.5000, -817152.5000, -813056.5000],
         [-808960.5000, -804864.5000, -800768.5000],
         [-796672.5000, -792576.5000, -788480.5000],
         [-784384.5000, -780288.5000, -776192.5000],
         [-772096.5000, -768000.5000, -763904.5000],
         [-759808.5000, -755712.5000, -751616.5000],
         [-747520.5000, -743424.5000, -739328.5000]], dtype=torch.float64),
 'mean^2': tensor([[7.7914e+11, 7.7193e+11, 7.6475e+11],
         [7.5760e+11, 7.5048e+11, 7.4340e+11],
         [7.3636e+11, 7.2935e+11, 7.2237e+11],
         [7.1542e+11, 7.0851e+11, 7.0163e+11],
         [6.9478e+11, 6.8797e+11, 6.8119e+11],
         [6.7445e+11, 6.6774e+11, 6.6106e+11],
         [6.5442e+11, 6.4781e+11, 6.4123e

In [23]:
# Run the statistics engine over the synthetic tensor; each iteration receives one (b, c, h, w) slice.
state = compute_engine.run(train_loader)

HBox(children=(IntProgress(value=0, description='Compute Mean/Std', max=8, style=ProgressStyle(description_wid…

In [24]:
# Should match the NumPy reference `mean` / `mean2` computed on `true_data`.
state.metrics

{'mean': tensor([-299008.5000, -294912.5000, -290816.5000], dtype=torch.float64),
 'mean2': tensor([2.0536e+11, 2.0293e+11, 2.0053e+11], dtype=torch.float64)}

In [25]:
# Per-channel std (ddof=0): sqrt(E[x^2] - E[x]^2). The variance is clamped at 0
# because float cancellation can make it slightly negative for near-constant
# channels, which sqrt would turn into NaN.
state.metrics['std'] = torch.sqrt(torch.clamp(state.metrics['mean2'] - state.metrics['mean'] ** 2, min=0.0))

In [26]:
# Should match the NumPy reference `std` computed on `true_data`.
state.metrics['std']

tensor([340518.6224, 340518.6224, 340518.6224], dtype=torch.float64)

In [16]:
import random

from ignite.engine import Engine, Events


def update_model(engine, batch):
    """Stand-in training step: ignore the inputs and return a random 'loss' in [0, 1)."""
    fake_loss = random.random()
    return fake_loss


# Toy engine for demonstrating periodic logging; `update_model` stands in for a training step.
trainer = Engine(update_model)

@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_training(engine):
    """Print epoch, global iteration, the latest step output and a fixed lr every 100 iterations."""
    st = engine.state
    lr = 0.01  # placeholder; nothing in this demo schedules a learning rate
    print(f"Epoch {st.epoch}/{st.max_epochs} : {st.iteration} - batch loss: {st.output}, lr: {lr}")

    
# 1000 dummy items per epoch for 5 epochs; `log_training` fires on every 100th iteration.
trainer.run(list(range(1000)), max_epochs=5)

Epoch 1/5 : 100 - batch loss: 0.10874069479016124, lr: 0.01
Epoch 1/5 : 200 - batch loss: 0.011986311293506247, lr: 0.01
Epoch 1/5 : 300 - batch loss: 0.2531812456307698, lr: 0.01
Epoch 1/5 : 400 - batch loss: 0.8978772717472793, lr: 0.01
Epoch 1/5 : 500 - batch loss: 0.44192087329825613, lr: 0.01
Epoch 1/5 : 600 - batch loss: 0.738644399068905, lr: 0.01
Epoch 1/5 : 700 - batch loss: 0.6703575663223107, lr: 0.01
Epoch 1/5 : 800 - batch loss: 0.1586780391733491, lr: 0.01
Epoch 1/5 : 900 - batch loss: 0.1917330500497959, lr: 0.01
Epoch 1/5 : 1000 - batch loss: 0.002969448472456837, lr: 0.01
Epoch 2/5 : 1100 - batch loss: 0.26629197951445027, lr: 0.01
Epoch 2/5 : 1200 - batch loss: 0.36764951402604584, lr: 0.01
Epoch 2/5 : 1300 - batch loss: 0.45918256295209015, lr: 0.01
Epoch 2/5 : 1400 - batch loss: 0.7598764350680366, lr: 0.01
Epoch 2/5 : 1500 - batch loss: 0.3932501190109595, lr: 0.01
Epoch 2/5 : 1600 - batch loss: 0.28595545931561117, lr: 0.01
Epoch 2/5 : 1700 - batch loss: 0.4217900

State:
	output: 0.6721766286931133
	batch: 999
	dataloader: <class 'list'>
	max_epochs: 5
	metrics: <class 'dict'>
	iteration: 5000
	epoch: 5