In [1]:
# Run pre-training
!python3 ../code/pretraining.py --config="../configs/pretraining/config.yaml"

In [None]:
# # Run experiment
# !python3 ../code/segmentation.py --config="../configs/segmentation/config.yaml"

In [None]:
# # Run diffusion model
# !python3 ../code/diffusion.py --config="../configs/diffusion/config.yaml"

In [None]:
# # Evaluate experiment
# !python3 ../code/evaluate.py --config="../configs/segmentation/config.yaml"

In [None]:
# # Folds
# 0.6587 -- original (efficientnetv2_m)
# 0.6563 -- original (resnest101e)
# 0.7173 -- pretraining (efficientnetv2_m)
# 0.6870 -- pretraining (resnest101e)
# 0.7216 -- finetuning (efficientnetv2_m)
# 0.6922 -- finetuning (resnest101e)

# # Validation folder
# 0.6210 -- original (efficientnetv2_m)
# 0.6310 -- original (resnest101e)
# 0.6412 -- pretraining (efficientnetv2_m)
# 0.6450 -- pretraining (resnest101e)
# 0.6487 -- finetuning (efficientnetv2_m)
# 0.6521 -- finetuning (resnest101e)

In [None]:
# # Create and save synthetic contrail images
# !python3 ../code/synthesize.py --config="../configs/diffusion/config.yaml"

In [None]:
import sys
sys.path.append("../code/")

import matplotlib.pyplot as plt

from dataset import ContrailsPretrainingDataset
from utils import data_split, normalize_min_max, load_synthetic_metadata

# Build the fold-0 validation pre-training dataset so single samples can be inspected in the next cells.
df = data_split("../data/data_split.csv")
# Fold 0 is held out; every other fold goes to the training frame.
df_train = df[df.fold != 0]
df_valid = df[df.fold == 0]

# Keep only the synthetic-metadata rows tagged with fold 0.
# NOTE(review): the frame is named "train" but is handed to the validation dataset below — confirm this is intended.
df_synthetic_train = load_synthetic_metadata()
df_synthetic_train = df_synthetic_train[df_synthetic_train.fold == 0]

# NOTE(review): the disabled call below passes no synthetic-metadata argument, unlike the active call — presumably stale; verify against ContrailsPretrainingDataset.
# train_dataset = ContrailsPretrainingDataset(df_train, "../data/pseudo-labels/predictions/", 8, 384, split="train")
# The positional 8 and 384 are presumably the number of timesteps and the image size — TODO confirm.
valid_dataset = ContrailsPretrainingDataset(df_valid, df_synthetic_train, "../data/pseudo-labels/predictions/", 8, 384, split="validation")

In [None]:
sample = valid_dataset[0]

In [None]:
sample["mask"].max()

In [None]:
from einops import rearrange

fig, axs = plt.subplots(1, 2, figsize=(8, 4))

axs[0].imshow(normalize_min_max(rearrange(sample["image"], "c h w -> h w c")))
axs[1].imshow(sample["mask"][0])

In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Show one randomly chosen synthetic sample (conditioning map on the left, generated image on the right).
meta = pd.read_csv("../data/synthetic/fold-0/metadata.csv")

# Unseeded draw: a different row is shown on every run.
i = np.random.choice(meta.shape[0])
# NOTE(review): cv2.imread returns channels in BGR order and returns None (no exception) for an
# unreadable path. Whether the colours shown below are correct depends on how the files were
# written — confirm against the synthesis script.
image = cv2.imread(meta.iloc[i].image_path)
condition = cv2.imread(meta.iloc[i].condition_path, cv2.IMREAD_GRAYSCALE)

fig, axs = plt.subplots(1, 2, figsize=(8, 4))
axs[0].imshow(condition)
axs[1].imshow(image)

In [None]:
condition.shape

In [None]:
import sys
sys.path.append("../../../mids-2023/latent-diffusion")
sys.path.append("../../../mids-2023/taming-transformers")
sys.path.append("../code")

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import yaml

from dataset import SemanticSynthesisDataset
from diffusion import DiffusionModule
from einops import rearrange
from functools import reduce
from ldm.util import instantiate_from_config
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

In [None]:
def normalize_image(img, permute=True):
    """Convert a model-space image tensor in [-1, 1] to a uint8 array in [0, 255].

    Args:
        img: torch.Tensor. Values outside [-1, 1] are clamped. When ``permute``
            is True the tensor must be 3-D and channel-first (c, h, w).
        permute: if True, return the array as (h, w, c); otherwise keep the
            input axis order.

    Returns:
        np.ndarray of dtype uint8. Fractional values are truncated, not rounded.
    """
    # detach() makes this safe for tensors that still track gradients;
    # Tensor.numpy() raises on those otherwise.
    x = torch.clamp((img.detach() + 1.0) / 2.0, min=0.0, max=1.0).cpu().numpy()
    if permute:
        # channel-first -> channel-last
        x = np.transpose(x, (1, 2, 0))
    return (255. * x).astype(np.uint8)


def inference(model, batch, inference_steps, eta, guidance_scale):
    """Sample images conditioned on segmentation maps, guided against an all-zero map.

    Args:
        model: object exposing ``device``, ``get_learned_conditioning``,
            ``sample_log`` and ``decode_first_stage``.
        batch: dict whose ``"segmentation"`` entry is a channel-last tensor
            (b, h, w, c).
        inference_steps: number of DDIM steps passed to ``sample_log``.
        eta: DDIM eta passed to ``sample_log``.
        guidance_scale: unconditional guidance scale passed to ``sample_log``.

    Returns:
        Tuple ``(seg, samples)``: the float, channel-first (b, c, h, w)
        segmentation tensor and the decoded samples.
    """
    seg = batch["segmentation"].float()
    bsz = seg.size(0)
    with torch.no_grad():
        # channel-last (b, h, w, c) -> channel-first (b, c, h, w)
        seg = seg.permute(0, 3, 1, 2)
        cond = model.get_learned_conditioning(seg)
        # Build the empty conditioning directly on the model's device instead of
        # allocating on the host and copying; dtype follows `seg`.
        empty_seg = torch.zeros(seg.shape, dtype=seg.dtype, device=model.device)
        uncond = model.get_learned_conditioning(empty_seg)
        samples, _ = model.sample_log(
            cond=cond,
            batch_size=bsz,
            ddim=True,
            ddim_steps=inference_steps,
            eta=eta,
            unconditional_guidance_scale=guidance_scale,
            unconditional_conditioning=uncond
        )
        samples = model.decode_first_stage(samples)
    return seg, samples


def inference_batch(model, loader, inference_steps, eta, guidance_scale):
    """Run `inference` over every batch of a loader and collect the results.

    Args:
        model: sampling model; each batch is moved to ``model.device`` first.
        loader: iterable of dicts with a ``"segmentation"`` tensor.
        inference_steps, eta, guidance_scale: forwarded unchanged to `inference`.

    Returns:
        Tuple ``(outputs, conditions)`` of arrays stacked along a new first
        axis: the uint8 images from `normalize_image` and the matching
        segmentation maps as fed to the model.
    """
    generated, segmentations = [], []
    for raw in loader:
        # Only the segmentation map is needed for sampling.
        on_device = {"segmentation": raw["segmentation"].to(model.device)}
        decoded = inference(
            model,
            on_device,
            inference_steps=inference_steps,
            eta=eta,
            guidance_scale=guidance_scale
        )[1]
        for idx, sample in enumerate(decoded):
            generated.append(normalize_image(sample))
            segmentations.append(on_device["segmentation"][idx].cpu().numpy())
    return np.stack(generated, axis=0), np.stack(segmentations, axis=0)

In [None]:
# Load the trained diffusion model for each fold and sample a small batch of images.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

# Sampling settings
inference_steps = 500  # DDIM steps
guidance_scale = 1.2
eta = 1.0
n = 16  # number of rows to condition on; also used as the batch size
folds = [0]

# NOTE(review): yaml.FullLoader can construct more than plain data types. Fine for a
# trusted local config; prefer yaml.safe_load if the file only contains plain YAML.
with open("../configs/diffusion/config.yaml", "rb") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

for fold in folds:
    checkpoint_path = f"../models/checkpoints/semantic_synthesis__fold_{fold}.ckpt"

    model = DiffusionModule(config["model"])
    # map_location keeps CUDA-saved checkpoints loadable when the CPU fallback above is taken.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device)["state_dict"])
    m = model.model.to(device)
    m.eval()

    # Condition on the first `n` non-validation rows outside the current fold.
    df = pd.read_csv("../data/data_split.csv")
    df_train = df[(df.fold != fold) & (df.split != "validation")].iloc[:n]
    dataset = SemanticSynthesisDataset(df_train, cond_drop_rate=0., split="train")
    dataloader = DataLoader(dataset, batch_size=n, shuffle=True)

    # Overwritten on every iteration: only the last fold's outputs survive the loop.
    inference_outputs = inference_batch(
        m,
        dataloader,
        inference_steps=inference_steps,
        eta=eta,
        guidance_scale=guidance_scale
    )

    # Disabled: saving outputs and freeing memory. These lines use names that are not
    # defined in this cell (diffusion_family, checkpoint, save_outputs, os, gc).
    # output_dir = f"/home/romainlhardy/mids-2023/data/steatosis/synthetic-outputs/{diffusion_family}/fold-{fold}/checkpoint-{checkpoint}"
    # os.makedirs(output_dir, exist_ok=True)
    # os.makedirs(f"{output_dir}/images", exist_ok=True)

    # save_outputs(inference_outputs, output_dir, diffusion_family)

    # del model, dataset, dataloader, inference_outputs
    # torch.cuda.empty_cache()
    # gc.collect()

In [None]:
# Plot every generated sample next to the segmentation map it was conditioned on.
outputs, conditions = inference_outputs
# Size the grid from the data instead of assuming exactly 16 samples.
n_samples = len(outputs)
# One grey level per label channel, so the one-hot map collapses to a single image.
label_values = np.array([0, 60, 120, 180, 240])
# squeeze=False keeps `axs` two-dimensional even when there is a single row.
fig, axs = plt.subplots(n_samples, 2, figsize=(8, n_samples * 4), squeeze=False)
for i in range(n_samples):
    axs[i, 0].imshow((conditions[i] * label_values[None, None, :]).sum(axis=-1))
    axs[i, 1].imshow(outputs[i])

In [None]:
import sys
sys.path.append("../code/")

import os
import yaml

from evaluate import load_model


# Load the segmentation config and restore a trained model from its checkpoint.
# This rebinds `config` and `model`, which the diffusion cell earlier in the notebook also assigns.
# NOTE(review): yaml.FullLoader can construct more than plain data types; yaml.safe_load is
# sufficient if the config only contains plain YAML — confirm before switching.
with open("../configs/segmentation/config.yaml", "rb") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# The checkpoint file name hard-codes fold 0 and the "-v4" suffix.
checkpoint_path = os.path.join(config["output_dir"], f"backbone_{config['model']['encoder']}__fold_0-v4.ckpt")
model = load_model(config["model"], checkpoint_path)

In [None]:
import sys
sys.path.append("../code")

import pandas as pd

from dataset import SemanticSynthesisDataset

# model = load_base_model("../configs/diffusion/model.yaml")
df = pd.read_csv("../data/data_split.csv")
dataset = SemanticSynthesisDataset(df, n_labels=5, cond_drop_rate=0.)

In [None]:
sample = dataset[2]

In [None]:
sample["image"].min(), sample["image"].max()

In [None]:
import matplotlib.pyplot as plt

fig, axs = plt.subplots(1, 5, figsize=(24, 4))

for i in range(5):
    axs[i].imshow(sample["segmentation"][..., i])

In [None]:
plt.imshow(sample["image"])

In [None]:
import sys
sys.path.append("../code/")

import timm
import torch
import torch.utils.checkpoint as checkpoint
from nextvit import nextvit_base

m = timm.create_model("timm/convnextv2_base.fcmae_ft_in22k_in1k_384", pretrained=True)

In [None]:
m.pretrained_cfg

In [None]:
import torch.nn as nn
stages = [
    nn.Identity(),
    m.stem,
    m.stages[0],
    m.stages[1],
    m.stages[2],
    m.stages[3]
]

In [None]:
m

In [None]:
len(m.stages)

In [None]:
x = torch.randn(2, 3, 256, 256)

for block in stages:
    x = block(x)
    print(x.shape)

In [None]:
import timm
import torch
import torch.nn as nn 

import sys
sys.path.append("../code/")
from encoder import ConvNeXtEncoder, NextViTEncoder, EfficientNetEncoder, ResNestEncoder
from model import Unet

# m = NextViTEncoder("nextvit_base")
# m = timm.create_model("convnextv2_tiny.fcmae_ft_in22k_in1k", pretrained=True)
# m = ConvNeXtEncoder("convnextv2_base.fcmae_ft_in22k_in1k_384", timesteps=1)
# m = EfficientNetEncoder("tf_efficientnetv2_m.in21k_ft_in1k", stage_idxs=[2, 3, 5])
# m = EfficientNetEncoder("tf_efficientnetv2_s.in21k_ft_in1k")
# m = ResNestEncoder("resnest101e.in1k", timesteps=1)
m = Unet(
    encoder_name="convnextv2_base.fcmae_ft_in22k_in1k_384",
    decoder_use_batchnorm=True,
    decoder_channels=[512, 256, 128, 64, 32],
    decoder_attention_type="scse",
    classes=1,
    activation=None,
    aux_params=None,
    timesteps=1
)

In [None]:
import torch

m(torch.randn(2, 1, 3, 384, 384)).shape
# for x in m(torch.randn(2, 1, 3, 384, 384)):
#     print(x.shape)

In [None]:
m.attention1.in_channels

In [None]:
class SCSEModule(nn.Module):
    """Channel, spatial and temporal squeeze-and-excitation over a stack of frames.

    The input is a 4-D tensor whose first axis is reshaped to
    (-1, timesteps, in_channels, h, w), so it must hold ``timesteps``
    consecutive frames per sample. Each frame is gated per channel (cSE) and
    per pixel (sSE); the frames of a sample are then weighted per timestep
    (tSE) and summed.

    Args:
        in_channels: number of channels of the input feature map.
        timesteps: number of frames per sample.
        reduction: channel reduction factor of the cSE bottleneck.
    """

    def __init__(self, in_channels, timesteps=8, reduction=16):
        super().__init__()
        self.in_channels = in_channels
        self.timesteps = timesteps
        # Keep at least one bottleneck channel: `in_channels // reduction` is 0
        # whenever in_channels < reduction, which would give a zero-width conv.
        squeezed_channels = max(1, in_channels // reduction)
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, squeezed_channels, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed_channels, in_channels, 1),
            nn.Sigmoid(),
        )
        self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())
        # Applied to the 5-D (batch, timesteps, channels, h, w) tensor, so the
        # timestep axis sits in the channel position of the 3-D layers.
        self.tSE = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Conv3d(timesteps, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv3d(1, timesteps, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Gate each frame, then merge the frames of every sample.

        Args:
            x: tensor of shape (batch * timesteps, in_channels, h, w).

        Returns:
            Tensor of shape (batch, in_channels, h, w).
        """
        h, w = x.size()[-2:]
        x = x * self.cSE(x) + x * self.sSE(x)
        # Split the leading axis back into (batch, timesteps).
        x = x.view(-1, self.timesteps, self.in_channels, h, w)
        # Weight every timestep, then sum the timestep axis away.
        x = torch.sum(x * self.tSE(x), dim=1)
        return x

    
scse = SCSEModule(24)

scse(torch.randn(32, 24, 192, 192)).shape

In [None]:
nn.AdaptiveAvgPool2d(1)(torch.randn(1, 24, 192, 192)).shape

In [None]:
nn.Conv2d(24, 1, 1)(torch.randn(1, 24, 192, 192)).shape

In [None]:
# Trace the feature-map shape after the stem and after each stage for a 256x256 input.
# NOTE(review): this cell needs `m` to expose `.stem` and `.stages`. Earlier cells bind `m`
# both to a timm model and to a Unet, so the result depends on which cell ran last — verify.
x = torch.randn(1, 3, 256, 256)

x = m.stem(x)
print(x.shape)
for stage in m.stages:
    x = stage(x)
    print(x.shape)

In [None]:
m

In [None]:
len(m.stages)

In [None]:
import torch

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Load the SAM ViT-B checkpoint onto the best available device.
sam_checkpoint = "../models/sam-encoders/sam_vit_b_01ec64.pth"
model_type = "vit_b"

# Fall back to CPU instead of failing outright on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device)

In [None]:
from segment_anything.modeling import image_encoder

In [None]:
import sys
sys.path.append("../code")

from encoder import SAMEncoder
from model import Unet

In [None]:
m = Unet()

In [None]:
import torch

m(torch.randn(1, 3, 384, 384)).shape

In [None]:
# NOTE(review): `features` is not assigned in any cell visible in this notebook, so on a fresh
# kernel this presumably raises NameError — confirm where it was meant to come from.
[f.shape for f in features]

In [None]:
state_dict = torch.load("../models/sam-encoders/sam_vit_b_01ec64.pth")

In [None]:
m = image_encoder.ImageEncoderViT(
    img_size=384,
    embed_dim=768,
    depth=12,
    num_heads=12,
    global_attn_indexes=[2, 5, 8, 11],
    patch_size=16,
    use_abs_pos=False
)

In [None]:
# Keep only the weights that do not belong to the mask decoder or prompt encoder,
# and remove the "image_encoder." qualifier from their names.
state_dict = dict(
    (name.replace("image_encoder.", ""), weight)
    for name, weight in state_dict.items()
    if not name.startswith(("mask_decoder", "prompt_encoder"))
)

In [None]:
m.load_state_dict(state_dict, strict=False)

In [None]:
m.patch_embed(torch.randn(1, 3, 384, 384)).shape

In [None]:
import math

int(math.log(16, 2))

In [None]:
m.pos_embed

In [None]:
with torch.no_grad():
    m(torch.randn(1, 3, 384, 384))

In [None]:
import sys
sys.path.append("../code")

import ast
import argparse
import cv2
import gc
import numpy as np
import os
import pandas as pd
import torch
import yaml

from dataset import ContrailsDataset
from segmentation import SegmentationModule
from torchmetrics import Dice
from torch.utils.data import DataLoader
from utils import data_split, load_record, FOLDS

df = data_split("../data/data_split.csv")
# df_analysis = pd.read_csv("../evaluation/analysis_metadata.csv")

# len(df_analysis)

In [None]:
df = data_split("../data/data_split.csv")
df_train = df[df.fold != 0]
df_valid = df[df.fold == 0]
train_dataset = ContrailsDataset(df_train, 5, 384, split="train")
valid_dataset = ContrailsDataset(df_valid, 5, 384, split="validation")

In [None]:
valid_dataset[0]["image"].shape

In [None]:
# Load the analysis metadata here: its only other definition in this notebook is
# commented out, so this cell would raise NameError on a fresh kernel.
df_analysis = pd.read_csv("../evaluation/analysis_metadata.csv")

# Pick one analysed record, then load its saved prediction and its raw record.
i = 5
row = df_analysis.iloc[i]
prob = np.load(f"../evaluation/predictions/{row.record_id}.npy")
x, y = load_record(df[df.record_id == row.record_id].iloc[0].record_path)

In [None]:
import matplotlib.pyplot as plt

fig, axs = plt.subplots(1, 3, figsize=(12, 4))
axs[0].imshow(x)
axs[1].imshow(y)
axs[2].imshow(prob[0])

In [None]:
from utils import load_record

x, y = load_record(df.iloc[0].record_path)

In [None]:
dataset = ContrailsDataset(df)

sample = dataset[0]

In [None]:
sample["image"].shape

In [None]:
import sys
sys.path.append("../code")

from encoder import Encoder 
from model import Unet

# m = Encoder("tf_efficientnetv2_l_in21ft1k", stage_idxs=(2, 3, 5))
m = Unet(
    encoder_name="tf_efficientnetv2_l_in21ft1k",
    decoder_use_batchnorm=True,
    decoder_channels=(256, 128, 64, 32, 16),
    decoder_attention_type=None,
    classes=1,
    activation=None,
    aux_params=None
)


In [None]:
import torch

# for x in m(torch.randn(1, 3, 384, 384)):
#     print(x.shape)

m(torch.randn(1, 3, 256, 256))

In [None]:
f = m.encoder.forward(torch.randn(1, 3, 256, 256))
f

In [None]:
f = f[1:][::-1]

In [None]:
m.decoder.blocks[0](f[0], f[1])

In [None]:
import inspect

print(inspect.getsource(m.decoder.forward))

In [None]:
def forward(features, decoder=None):
    """Step through a decoder by hand, printing tensor shapes before each block.

    Args:
        features: sequence of encoder outputs. The first entry is discarded
            and the rest are consumed in reverse order.
        decoder: object exposing ``center`` and ``blocks``. Defaults to the
            notebook-global ``m.decoder``.

    Returns:
        The output of the last decoder block.
    """
    if decoder is None:
        decoder = m.decoder

    # Remove the first skip (same spatial resolution as the input), then
    # reverse so the head of the encoder comes first.
    features = features[1:][::-1]
    head, skips = features[0], features[1:]

    x = decoder.center(head)
    for i, decoder_block in enumerate(decoder.blocks):
        skip = skips[i] if i < len(skips) else None
        # Blocks beyond the available skips receive None, so guard the debug
        # print instead of dereferencing `skip.shape` unconditionally.
        print(x.shape, None if skip is None else skip.shape)
        x = decoder_block(x, skip)

    return x

In [None]:
forward(f)

In [None]:
m.model.stem(torch.randn(1, 3, 384, 384)).shape

In [None]:
m.model.stem