In [None]:
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import pandas as pd
from matplotlib import pyplot as plt
from utils.imageIO import read_img

In [None]:
import os

# Stage 1 (pre-training): batches of SIZE x SIZE crops.
BATCH_SIZE = 32
SIZE = 256
EPOCHS = 25

# Stage 2 (fine-tuning): one TUNE_SIZE x TUNE_SIZE crop per batch.
TUNE_BATCH_SIZE = 1
TUNE_SIZE = 1024
TUNE_EPOCHS = 2

# Project root. Set the RESERVOIR_ROCK_ROOT environment variable to run on another machine;
# the default reproduces the original absolute paths. os.path.join keeps paths portable.
PROJECT_ROOT = os.environ.get(
    'RESERVOIR_ROCK_ROOT',
    r'C:\Users\Viktor\Documents\IT\ReservoirRockAnalysis',
)
path_to_train = os.path.join(PROJECT_ROOT, 'data', 'train-test', 'ssl-train-data.xlsx')
path_to_test = os.path.join(PROJECT_ROOT, 'data', 'train-test', 'ssl-test-data.xlsx')
path_to_models = os.path.join(PROJECT_ROOT, 'resources')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Derived from `device` so the string and the device object can never disagree.
device_str = device.type
print(f'Using {device} device')

In [None]:
# Generator: U-Net with a timm EfficientNet-B3 encoder that maps a 1-channel input
# to a 3-channel output passed through a sigmoid (values in [0, 1]).
model = smp.Unet(
    encoder_name="timm-efficientnet-b3",
    in_channels=1,
    classes=3,
    activation='sigmoid',
).to(device)

In [None]:
# Shape probe for the generator. Eval mode + no_grad keep this random input from updating
# BatchNorm running statistics or building an autograd graph; the previous mode is restored.
generator_was_training = model.training
model.eval()
with torch.no_grad():
    generator_out_shape = model(torch.randn(1, 1, SIZE, SIZE).to(device)).shape
model.train(generator_was_training)
generator_out_shape

In [None]:
# https://www.kaggle.com/code/fatemehfarnaghizadeh/pix2pix-gan

class CNNBlock(nn.Module):
    """Conv2d(kernel 4) -> BatchNorm2d -> LeakyReLU(0.2) block used inside the discriminator.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        stride: convolution stride.
        padding: reflect padding added on each side before the convolution. The default 0
            keeps the original behaviour (no padding, so ``padding_mode='reflect'`` is a no-op).
    """

    def __init__(self, in_channels, out_channels, stride=2, padding=0):
        super().__init__()
        # Attribute name `conv` is part of the state_dict keys; keep it stable for checkpoints.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, stride, padding=padding,
                      bias=False, padding_mode='reflect'),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )

    def forward(self, x):
        return self.conv(x)

class Discriminator(nn.Module):
    """Convolutional discriminator that scores a pair of images patch-wise.

    ``forward(x, y)`` concatenates ``x`` and ``y`` along the channel axis, so ``in_channels``
    must equal their combined channel count (the default 4 matches a 3-channel image plus a
    1-channel image).

    Args:
        in_channels: total channels of the concatenated pair.
        features: output channels of the initial conv followed by one CNNBlock per remaining
            entry. Every block uses stride 2 except the last, which uses stride 1.

    Raises:
        ValueError: if ``features`` is empty.
    """

    def __init__(self, in_channels=4, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)
        if not features:
            raise ValueError('features must contain at least one channel count')

        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, features[0], 4, stride=2, padding_mode='reflect', padding=1),
            nn.LeakyReLU(0.2)
        )

        layers = []
        channels = features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            # Decide stride by position, not by value: comparing `feature == features[-1]`
            # would give stride 1 to every block whose width repeats the last one.
            layers.append(CNNBlock(channels, feature, stride=1 if index == last_index else 2))
            channels = feature

        # Attribute names `initial`, `model`, `final` are state_dict keys; keep them stable.
        self.model = nn.Sequential(*layers)
        self.final = nn.Conv2d(channels, 1, 4, stride=1, padding=1, padding_mode='reflect')

    def forward(self, x, y):
        """Return a (N, 1, H', W') map of sigmoid scores in [0, 1] for the pair (x, y).

        ``x`` and ``y`` must share batch and spatial sizes; they are concatenated on dim 1.
        """
        pair = torch.cat([x, y], dim=1)
        out = self.initial(pair)
        out = self.model(out)
        return torch.sigmoid(self.final(out))
    
discriminator = Discriminator().to(device)

# Shape probe for the discriminator's score map. Eval mode + no_grad keep the random inputs
# from updating BatchNorm running statistics or building a graph; the previous mode is restored.
discriminator_was_training = discriminator.training
discriminator.eval()
with torch.no_grad():
    discriminator_out_shape = discriminator(
        torch.randn(2, 3, SIZE, SIZE).to(device),
        torch.randn(2, 1, SIZE, SIZE).to(device),
    ).shape
discriminator.train(discriminator_was_training)
discriminator_out_shape

In [None]:
train_df = pd.read_excel(path_to_train, index_col=0)
# ColorizationDataset reads column 0 (RGB image path) and column 1 (grayscale image path)
# positionally, so fail early if the sheet does not have both.
assert len(train_df) > 0, 'train sheet is empty'
assert train_df.shape[1] >= 2, 'train sheet needs at least two path columns'
train_df.head()

In [None]:
test_df = pd.read_excel(path_to_test, index_col=0)
# Same positional two-column layout as the training sheet (RGB path, grayscale path).
assert len(test_df) > 0, 'test sheet is empty'
assert test_df.shape[1] >= 2, 'test sheet needs at least two path columns'
test_df.head()

In [None]:
def transform_train(SIZE):
    """Training transform: random SIZE x SIZE crop, then conversion to torch tensors.

    ``additional_targets`` makes ``target_image`` receive the same crop as ``image``,
    so the input/target pair stays aligned. ``SIZE`` (crop side in pixels) shadows the
    notebook-level constant of the same name.
    """
    return A.Compose(
        transforms=[
            A.RandomCrop(SIZE, SIZE, p=1),
            ToTensorV2(p=1)
        ],
        additional_targets={'target_image': 'image'}
    )

def transform_valid(SIZE):
    """Validation transform: deterministic centre SIZE x SIZE crop, then conversion to tensors.

    A centre crop (rather than a random one) makes validation metrics comparable across
    epochs and makes the plotted validation samples reproducible. Images must be at least
    SIZE x SIZE, which is the same requirement the random crop had.
    """
    return A.Compose(
        transforms=[
            A.CenterCrop(SIZE, SIZE, p=1),
            ToTensorV2(p=1)
        ],
        additional_targets={'target_image': 'image'}
    )

In [None]:
class ColorizationDataset(Dataset):
    """Yields (grayscale input, RGB target) pairs whose file paths are listed in a DataFrame.

    Columns are addressed by position: column 0 holds the RGB image path and column 1 the
    grayscale image path. Both images go through ``transforms`` together (as ``image`` and
    ``target_image``) and are then divided by 255.
    """

    def __init__(self, df, transforms):
        super().__init__()
        self.df = df
        self.transforms = transforms

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        gray_path = self.df.iloc[index, 1]
        rgb_path = self.df.iloc[index, 0]

        # The grayscale image is read first, then the RGB one (keyword args evaluate in order).
        # NOTE(review): the /255 scaling assumes read_img returns 0..255 pixel values — confirm.
        sample = self.transforms(
            image=read_img(gray_path, rgb=False),
            target_image=read_img(rgb_path),
        )
        gray_input = sample['image'] / 255
        rgb_target = sample['target_image'] / 255
        return gray_input, rgb_target

In [None]:
# Pre-training datasets: SIZE x SIZE crops for both splits.
train_datasets, valid_datasets = (
    ColorizationDataset(train_df, transforms=transform_train(SIZE)),
    ColorizationDataset(test_df, transforms=transform_valid(SIZE)),
)

In [None]:
# One training pair: x is the grayscale input, y the RGB target.
x, y = train_datasets[0]
tuple(t.shape for t in (x, y))

In [None]:
# Side-by-side view of the training pair loaded above; x is (C, H, W), so channel 0 is
# shown for the grayscale input and y is permuted to (H, W, C) for imshow.
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
axs[0].imshow(x[0].cpu().numpy(), cmap='gray')
axs[0].set_title('Input (grayscale)')
axs[1].imshow(y.permute(1, 2, 0).cpu().numpy())
axs[1].set_title('Target (RGB)');

In [None]:
def collate_fn(batch):
    """Collate a list of (input, target) tensor pairs into two float32 batches on ``device``.

    Tensors are moved to ``device`` here, inside the DataLoader, so loading should stay
    single-process (the default ``num_workers=0``).

    Returns:
        tuple: (stacked inputs, stacked targets), each with a new leading batch dimension.
    """
    input_samples, target_samples = zip(*batch)
    return tuple(
        torch.stack(samples).to(dtype=torch.float, device=device)
        for samples in (input_samples, target_samples)
    )

In [None]:
# Pre-training loaders: shuffled training batches, fixed-order validation batches.
# collate_fn places batches on `device`, so the default single-process loading is kept.
train_loader = DataLoader(
    dataset=train_datasets,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    shuffle=True,
)
valid_loader = DataLoader(
    dataset=valid_datasets,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
)

In [None]:
from BlissLearn.BlissLearner import BlissColorizationLearner
from BlissLearn.BlissCallbacks.Callbacks import ColorizationMetricsCallback, PrintCriteriaCallback
from utils.metrics import accuracy

In [None]:
# Metrics reported during training: MAE for the generator and accuracy for the discriminator,
# followed by a callback that prints the collected criteria.
callbacks = [
    ColorizationMetricsCallback(
        common_generator_metrics=dict(MAE=nn.L1Loss()),
        common_discriminator_metrics=dict(Accuracy=accuracy),
    ),
    PrintCriteriaCallback(),
]

In [None]:
# Pre-training learner. The discriminator's learning rate is the generator's divided by 1.5.
# NOTE(review): how `alpha` weights the losses is defined inside BlissColorizationLearner
# (not visible here).
BASE_LR = 0.0002
learner = BlissColorizationLearner(
    generator=model,
    discriminator=discriminator,
    generator_loss_function=nn.L1Loss(),
    discriminator_loss_function=nn.BCELoss(),
    generator_optimizer_class=optim.Adam,
    generator_optimizer_kwargs={'lr': BASE_LR},
    discriminator_optimizer_class=optim.Adam,
    discriminator_optimizer_kwargs={'lr': BASE_LR / 1.5},
    train_dataloader=train_loader,
    test_dataloader=valid_loader,
    callbacks=callbacks,
    alpha=0.08,
)

In [None]:
# Stage 1: pre-train generator and discriminator on SIZE x SIZE crops
# (EPOCHS is presumably an epoch count — the loop lives in BlissColorizationLearner).
learner.fit(EPOCHS)

In [None]:
# Fine-tuning datasets: rebuild both splits with TUNE_SIZE x TUNE_SIZE crops
# (rebinds the names used by the pre-training cells).
train_datasets, valid_datasets = (
    ColorizationDataset(train_df, transforms=transform_train(TUNE_SIZE)),
    ColorizationDataset(test_df, transforms=transform_valid(TUNE_SIZE)),
)

In [None]:
# Fine-tuning loaders over the TUNE_SIZE datasets, TUNE_BATCH_SIZE samples per batch.
train_loader = DataLoader(
    dataset=train_datasets,
    batch_size=TUNE_BATCH_SIZE,
    collate_fn=collate_fn,
    shuffle=True,
)
valid_loader = DataLoader(
    dataset=valid_datasets,
    batch_size=TUNE_BATCH_SIZE,
    collate_fn=collate_fn,
)

In [None]:
# Fine-tuning learner: continues training the same `model` and `discriminator` objects,
# with both learning rates set to 0.0002 / 5.
# NOTE(review): the `callbacks` instances from pre-training are reused; if they keep
# internal state, confirm that is intended.
TUNE_LR = 0.0002 / 5
tune_learner = BlissColorizationLearner(
    generator=model,
    discriminator=discriminator,
    generator_loss_function=nn.L1Loss(),
    discriminator_loss_function=nn.BCELoss(),
    generator_optimizer_class=optim.Adam,
    generator_optimizer_kwargs={'lr': TUNE_LR},
    discriminator_optimizer_class=optim.Adam,
    discriminator_optimizer_kwargs={'lr': TUNE_LR},
    train_dataloader=train_loader,
    test_dataloader=valid_loader,
    callbacks=callbacks,
    alpha=0.1,
    batches_to_validate=500,
)

In [None]:
# Stage 2: fine-tune on TUNE_SIZE x TUNE_SIZE crops using the reduced learning rates.
tune_learner.fit(TUNE_EPOCHS)

In [None]:
import os

# Save generator weights. os.path.join uses the platform separator; the previous
# `path_to_models + r"\..."` concatenation only produced a valid path on Windows.
torch.save(model.state_dict(), os.path.join(path_to_models, "ColorUnetEffNet.pkl"))

In [None]:
import os

# Save discriminator weights with a portable path join. The file name (including the
# "Dics" spelling) is kept as-is so existing loaders of this checkpoint keep working.
torch.save(discriminator.state_dict(), os.path.join(path_to_models, "Dics1EffNet.pkl"))

In [None]:
# Qualitative check on validation sample 10: grayscale input, RGB ground truth, prediction.
x, y = valid_datasets[10]

# Switch BatchNorm/Dropout to inference behaviour. The model is left in eval mode after
# this cell; call model.train() before any further training.
model.eval()
with torch.inference_mode():
    preds = model(x.unsqueeze(0).to(device))

fig, axs = plt.subplots(1, 3, figsize=(18, 6))
axs[0].imshow(x[0].cpu().numpy(), cmap='gray')
axs[0].set_title('Input (grayscale)')
axs[1].imshow(y.permute(1, 2, 0).cpu().numpy())
axs[1].set_title('Target (RGB)')
axs[2].imshow(preds[0].permute(1, 2, 0).cpu().numpy())
axs[2].set_title('Prediction');