In [1]:
import warnings
warnings.filterwarnings("ignore")

import gc
import timm
import os
import glob

import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl
import torchvision.transforms as T
import yaml

In [2]:
import shutil
import os

# Copy the ResNeSt weights shipped as a Kaggle input dataset into the
# Hugging Face hub cache directory (presumably so that
# `timm.create_model('resnest101e', pretrained=True)` can resolve them
# without network access -- confirm against the kernel's internet setting).
hf_hub_cache = '/root/.cache/huggingface/hub'
weights_repo = 'models--timm--resnest101e.in1k'
weights_src = os.path.join('/kaggle/input/preresnet/resnest101e', weights_repo)
weights_dst = os.path.join(hf_hub_cache, weights_repo)

# Pure-Python equivalent of `mkdir -p`; safe to re-run.
os.makedirs(hf_hub_cache, exist_ok=True)

try:
    shutil.copytree(weights_src, weights_dst)
except FileExistsError:
    # The destination already exists (e.g. this cell was run before).
    pass

os.listdir(weights_dst)

['refs', 'blobs', 'snapshots']

In [3]:
import yaml

# --- Inference settings -----------------------------------------------------
THR = 0.5          # cut-off applied to the averaged sigmoid output
batch_size = 32
num_workers = 1

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- Competition data locations ---------------------------------------------
data = '/kaggle/input/google-research-identify-contrails-reduce-global-warming'
# Trailing slash matters: record ids are appended by plain string concatenation.
data_root = '/kaggle/input/google-research-identify-contrails-reduce-global-warming/test/'

# Submission template, indexed by record id.
submission = pd.read_csv(
    os.path.join(data, 'sample_submission.csv'),
    index_col='record_id',
)

In [4]:
# One row per entry of the test directory: `record_id` is the entry name and
# `path` is `data_root` with that name appended.
filenames = os.listdir(data_root)
test_df = pd.DataFrame(filenames, columns=['record_id']).assign(
    path=lambda frame: data_root + frame['record_id'].astype(str)
)

In [5]:
class ContrailsDataset(torch.utils.data.Dataset):
    """Dataset yielding normalised false-colour images and their record ids.

    Every row of ``df`` must expose a ``path`` value pointing at a record
    directory that contains ``band_11.npy``, ``band_14.npy`` and
    ``band_15.npy``. The name of that directory is parsed as the integer
    record id, so the id returned for a sample always belongs to the same
    row the image was read from.
    """

    def __init__(self, df, image_size=256, train=True):
        """
        Args:
            df: DataFrame with a ``path`` column (one record directory per row).
            image_size: Target size passed to the resize transform; no resize
                is performed when it equals 256.
            train: Stored as ``self.trn``; not used elsewhere in this class.
        """
        self.df = df
        self.trn = train
        # Record ids aligned row-for-row with `df`. They are derived from the
        # paths in `df` itself rather than from a separate directory listing,
        # whose order and contents are not guaranteed to match `df`.
        self.df_idx: pd.DataFrame = pd.DataFrame(
            {'idx': [self._record_id_from_path(p) for p in df['path']]}
        )
        self.normalize_image = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        self.image_size = image_size
        if image_size != 256:
            self.resize_image = T.Resize(image_size)

    @staticmethod
    def _record_id_from_path(path):
        """Return the last component of ``path`` (trailing slashes ignored)."""
        return os.path.basename(os.path.normpath(str(path)))

    def read_record(self, directory):
        """Load the three band arrays of one record.

        Args:
            directory: Record directory containing ``<band>.npy`` files.

        Returns:
            Dict mapping ``"band_11"``, ``"band_14"`` and ``"band_15"`` to the
            arrays loaded with ``np.load``.
        """
        record_data = {}
        for x in [
            "band_11",
            "band_14",
            "band_15"
        ]:

            record_data[x] = np.load(os.path.join(directory, x + ".npy"))

        return record_data

    def normalize_range(self, data, bounds):
        """Linearly map ``bounds[0]`` to 0 and ``bounds[1]`` to 1."""
        return (data - bounds[0]) / (bounds[1] - bounds[0])

    def get_false_color(self, record_data):
        """Build the false-colour image for a single time step.

        The three channels are band differences / a band value rescaled with
        fixed bounds and clipped to ``[0, 1]``. The channels are stacked on
        axis 2 and the last axis is indexed at ``N_TIMES_BEFORE``.

        Args:
            record_data: Dict as returned by :meth:`read_record`.
                Assumes each band is laid out as (H, W, time) -- TODO confirm.

        Returns:
            Array obtained from ``false_color[..., N_TIMES_BEFORE]``.
        """
        _T11_BOUNDS = (243, 303)
        _CLOUD_TOP_TDIFF_BOUNDS = (-4, 5)
        _TDIFF_BOUNDS = (-4, 2)

        N_TIMES_BEFORE = 4

        r = self.normalize_range(record_data["band_15"] - record_data["band_14"], _TDIFF_BOUNDS)
        g = self.normalize_range(record_data["band_14"] - record_data["band_11"], _CLOUD_TOP_TDIFF_BOUNDS)
        b = self.normalize_range(record_data["band_14"], _T11_BOUNDS)

        false_color = np.clip(np.stack([r, g, b], axis=2), 0, 1)
        img = false_color[..., N_TIMES_BEFORE]

        return img

    def __getitem__(self, index):
        """Return ``(image, record_id)`` for the row at position ``index``.

        Returns:
            Tuple of a float32 tensor in channel-first layout and a scalar
            integer tensor holding the record id parsed from the row's path.
        """
        row = self.df.iloc[index]
        con_path = row.path
        record = self.read_record(con_path)

        img = self.get_false_color(record)

        # The reshape doubles as a shape check before moving channels first.
        img = torch.tensor(np.reshape(img, (256, 256, 3))).to(torch.float32).permute(2, 0, 1)

        if self.image_size != 256:
            img = self.resize_image(img)

        img = self.normalize_image(img)

        # Take the id from the very path the image was read from so the two
        # can never get out of step.
        image_id = int(self._record_id_from_path(con_path))

        return img.float(), torch.tensor(image_id)

    def __len__(self):
        """Number of rows in ``df``."""
        return len(self.df)

In [6]:
def rle_encode(x, fg_val=1):
    """Run-length encode the foreground pixels of a 2-D mask.

    The mask is flattened via ``x.T.flatten()`` and positions are reported
    1-based.

    Args:
        x: 2-D array-like mask supporting ``.T`` and ``.flatten()``.
        fg_val: Value that marks a foreground pixel.

    Returns:
        Flat list of plain Python ints ``[start, length, start, length, ...]``;
        an empty list when no pixel equals ``fg_val``.
    """
    dots = np.where(
        x.T.flatten() == fg_val)[0]
    run_lengths = []
    prev = -2
    # `.tolist()` yields built-in ints, so the result never contains NumPy
    # scalars (whose repr differs between NumPy versions).
    for b in dots.tolist():
        if b > prev + 1:
            # Gap since the previous foreground pixel: open a new run.
            run_lengths.extend((b + 1, 0))
        run_lengths[-1] += 1
        prev = b
    return run_lengths

def list_to_string(x):
    """Format a run-length list as a space-separated string.

    Args:
        x: Sequence of run-length values; may be empty.

    Returns:
        The values joined by single spaces, or ``'-'`` when ``x`` is empty.
    """
    if x:
        # Join element-wise with `str()` instead of post-processing
        # `str(list)`, which embeds each element's repr.
        return " ".join(str(v) for v in x)
    return '-'

In [7]:
class EncoderModule(nn.Module):
    """Feature extractor that exposes a timm ``resnest101e`` stage by stage.

    NOTE(review): the attribute names and the order of entries in ``stages``
    presumably determine the parameter names stored in saved checkpoints, so
    they should not be changed without re-exporting the weights -- confirm.
    """

    def __init__(self):
        super(EncoderModule, self).__init__()
        # `pretrained=True` makes timm load pretrained weights when the model is built.
        self.encoder = timm.create_model('resnest101e', pretrained=True)
        # Six consecutive stages; the leading Identity makes the unmodified
        # input the first entry of the feature list returned by `forward`.
        self.stages =  nn.ModuleList([
            nn.Identity(),
            nn.Sequential(self.encoder.conv1, self.encoder.bn1, self.encoder.act1),
            nn.Sequential(self.encoder.maxpool, self.encoder.layer1),
            self.encoder.layer2,
            self.encoder.layer3,
            self.encoder.layer4,
        ])

    def forward(self, x):
        """Run ``x`` through every stage in order.

        Returns:
            List with the output of each stage (one entry per element of
            ``self.stages``); the first entry is ``x`` itself and the last is
            the output of ``layer4``.
        """
        features = []
        # `feature` is the stage module; each stage consumes the previous output.
        for feature in self.stages:
            x = feature(x)
            features.append(x)
        return features


class UnetModule(nn.Module):
    """U-Net style decoder stacked on top of :class:`EncoderModule`.

    ``forward`` returns the output of a final 1x1 convolution with a single
    output channel; no activation is applied inside this module.

    NOTE(review): attribute names and module order presumably fix the
    parameter names in saved checkpoints -- confirm before renaming anything.
    """

    def __init__(self, upmapping="upconv"):
        """
        Args:
            upmapping: ``'upsample'`` selects nearest-neighbour ``nn.Upsample``
                for the expanding units; any other value selects the
                transposed-convolution variant.
        """
        super(UnetModule, self).__init__()
        self.upmapping = upmapping
        self.encoder = EncoderModule()
        # Per decoder stage: channels entering the expanding unit, channels
        # leaving the base unit, and channels of the skip feature that is
        # concatenated (0 for the last stage, which has no skip connection).
        # Presumably chosen to match the resnest101e stage widths -- TODO confirm.
        self.dec_in_c = [2048, 256, 128, 64, 32]
        self.dec_out_c = [256, 128, 64, 32, 16]
        self.skip_c = [1024, 512, 256, 128, 0]

        self.module_list = nn.ModuleList()
        for i in range(len(self.dec_in_c)):
            # Channel count coming out of the expanding unit: nn.Upsample
            # keeps it unchanged, the transposed-conv variant halves it.
            if upmapping == 'upsample':
                act_channels = self.dec_in_c[i]
            else:
                act_channels = self.dec_in_c[i]//2
            # Each stage is the pair [expanding unit, base unit].
            self.module_list.append(nn.ModuleList(
                [self.expanding_unit(self.dec_in_c[i], self.dec_in_c[i]//2, 2, 0),
                self.base_unit(act_channels + self.skip_c[i], self.dec_out_c[i], 3, 1)]))

        # 1x1 convolution reducing the last decoder output to one channel.
        self.final_conv = nn.Conv2d(self.dec_out_c[-1], 1, kernel_size=1)

    def base_unit(self, in_c, out_c, f, p):
        """Build two ``Dropout -> Conv2d -> BatchNorm2d -> LeakyReLU`` groups.

        Args:
            in_c: Input channels of the first convolution.
            out_c: Output channels of both convolutions.
            f: Kernel size of both convolutions.
            p: Padding of both convolutions.
        """
        return nn.Sequential(
            nn.Dropout(0.3),
            nn.Conv2d(in_c, out_c, kernel_size=f, padding=p),
            nn.BatchNorm2d(out_c),
            nn.LeakyReLU(),
            nn.Dropout(0.3),
            nn.Conv2d(out_c, out_c, kernel_size=f, padding=p),
            nn.BatchNorm2d(out_c),
            nn.LeakyReLU()
        )

    def expanding_unit(self, in_c, out_c, f, p):
        """Build the module applied at the start of a decoder stage.

        In ``'upsample'`` mode this is a nearest-neighbour ``nn.Upsample`` with
        scale factor 2 and the arguments are ignored. Otherwise it is a
        stride-2 ``ConvTranspose2d`` (``in_c`` -> ``out_c``, kernel ``f``,
        padding ``p``) followed by a kernel-size-1 ``ConvTranspose2d``, each
        with batch norm and LeakyReLU.
        """
        if self.upmapping == 'upsample':
            return nn.Upsample(scale_factor=2, mode='nearest')
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(in_c, out_c, f, padding=p, stride=2),
                nn.BatchNorm2d(out_c),
                nn.LeakyReLU(),
                nn.ConvTranspose2d(out_c, out_c, 1, stride=1, padding=0),
                nn.BatchNorm2d(out_c),
                nn.LeakyReLU()
            )


    def center_crop(self, encoder_out, decoder_out):
        """Crop ``encoder_out`` on dims 2 and 3 to the size of ``decoder_out``.

        The crop window is centred (offsets use integer division); when both
        tensors already have the same size on dims 2 and 3 the offsets are 0
        and the full extent is returned.
        """
        crop_dims = ((encoder_out.size(2) - decoder_out.size(2))//2,
                    (encoder_out.size(3) - decoder_out.size(3))//2)
        cropped_encoder_out = encoder_out[:, :,
                    crop_dims[0]:crop_dims[0] + decoder_out.size(2),
                    crop_dims[1]: crop_dims[1] + decoder_out.size(3)]
        return cropped_encoder_out


    def forward(self, x):
        """Encode ``x``, decode with skip connections and return the 1-channel map."""
        encoder = self.encoder(x)
        # Drop the first (identity) feature; the last feature seeds the
        # decoder and the remaining ones, reversed, serve as skip connections.
        features, x = encoder[1:][:-1][::-1], encoder[1:][-1]
        for i, module in enumerate(self.module_list):
            x = module[0](x)
            # Every stage except the last concatenates a skip feature on the
            # channel dimension.
            if i != len(self.module_list) - 1:
                crop4 = self.center_crop(features[i], x)
                x = torch.cat([x, crop4], 1)
            x = module[1](x)
        x = self.final_conv(x)
        return x

In [8]:
class LightningModule(pl.LightningModule):
    """Lightning wrapper around :class:`UnetModule` used for inference."""

    def __init__(self):
        """Create the wrapped U-Net with its default arguments."""
        super().__init__()
        # NOTE(review): the attribute name `model` presumably has to match the
        # parameter prefix stored in the checkpoints -- confirm before renaming.
        self.model = UnetModule()

    def forward(self, batch):
        """Return the wrapped network's output for ``batch``."""
        network_output = self.model(batch)
        return network_output

In [9]:
# Directory searched for `*.ckpt` files. The trailing slash is required
# because the glob pattern is built by plain string concatenation.
MODEL_PATH = "/kaggle/input/model-3/"

In [10]:
# Dataset over the test records, wrapped in a loader that uses the batch size
# and worker count from the configuration cell (default ordering, no shuffle).
test_ds = ContrailsDataset(test_df, train=False)

test_dl = DataLoader(
    dataset=test_ds,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [11]:
gc.enable()

# Sorted so the `f{i}` keys map to checkpoints in a reproducible order.
checkpoint_paths = sorted(glob.glob(MODEL_PATH + '*.ckpt'))
if not checkpoint_paths:
    raise FileNotFoundError(f"No '*.ckpt' files found under {MODEL_PATH!r}")

# all_preds["f<i>"][record_id] -> sigmoid output (NumPy array) of checkpoint i
all_preds = {}

for i, model_path in enumerate(checkpoint_paths):
    print(model_path)
    # `load_from_checkpoint` is a classmethod: call it on the class so no
    # throw-away model is built first, and map the weights straight onto the
    # device used for inference (also works on CPU-only machines).
    model = LightningModule.load_from_checkpoint(model_path, map_location=device)
    model.to(device)
    model.eval()

    model_preds = {}

    for images, image_ids in test_dl:
        images = images.to(device)

        with torch.no_grad():
            batch_probs = torch.sigmoid(model(images))

        batch_probs = batch_probs.cpu().numpy()

        # One entry per sample, keyed by its integer record id.
        for current_image_id, current_mask in zip(image_ids.tolist(), batch_probs):
            model_preds[current_image_id] = current_mask
    all_preds[f"f{i}"] = model_preds

    # Release the model before the next checkpoint is loaded.
    del model
    torch.cuda.empty_cache()
    gc.collect()

/kaggle/input/model-3/model_0_dice_score0.6291.ckpt


In [12]:
# Average the per-checkpoint predictions for every record, threshold the
# result and store its run-length encoding in the submission table.
fold_preds = list(all_preds.values())
if not fold_preds:
    raise RuntimeError("all_preds is empty: run the inference cell before building the submission")

for index in submission.index.tolist():
    # Build a fresh array for the mean. Accumulating with `+=` on the first
    # model's entry would modify the array stored in `all_preds` in place and
    # give different results every time this cell is re-run.
    predicted_mask = np.mean([preds[index] for preds in fold_preds], axis=0)

    # Strictly greater than THR counts as foreground (values equal to THR are background).
    predicted_mask_with_threshold = (predicted_mask[0, :, :] > THR).astype(np.float64)
    submission.loc[int(index), 'encoded_pixels'] = list_to_string(rle_encode(predicted_mask_with_threshold))

In [13]:
# Display the filled-in submission table as a visual sanity check.
submission

Unnamed: 0_level_0,encoded_pixels
record_id,Unnamed: 1_level_1
1000834164244036115,40965 3 41222 4 41479 5 41737 5 41994 5 42253 ...
1002653297254493116,-


In [14]:
# Write the submission file; the index (loaded as `record_id`) is written too.
submission.to_csv('submission.csv')