In [None]:
!wget https://solafune-dev-v1.s3.us-west-2.amazonaws.com/competitions/cloudmaskcompetition/train_mask.zip
# !wget https://solafune-dev-v1.s3.us-west-2.amazonaws.com/competitions/cloudmaskcompetition/sample.zip
!wget https://solafune-dev-v1.s3.us-west-2.amazonaws.com/competitions/cloudmaskcompetition/train_true_color.zip
# !wget https://solafune-dev-v1.s3.us-west-2.amazonaws.com/competitions/cloudmaskcompetition/evaluation_true_color.zip
!unzip -q train_mask.zip -d train_mask
# !unzip -q sample.zip -d sample
!unzip -q train_true_color.zip -d train_true_color
# !unzip -q evaluation_true_color.zip -d evaluation_true_color
!pip install rasterio
!pip install tifffile
!pip install transformers
!pip install accelerate
!pip install torchinfo
!pip install evaluate
!pip install peft

In [None]:
import os
import rasterio
import tifffile
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchinfo
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from torch import optim
import math
import seaborn as sns
from transformers import (
    MaskFormerImageProcessor,
    MaskFormerModel,
    MaskFormerConfig,
    MaskFormerForInstanceSegmentation
)
import evaluate
from sklearn.metrics import confusion_matrix
import torchinfo
from peft import LoraConfig, get_peft_model
# Run on the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Checkpoint used for both the image processor and the model.
model_name = 'facebook/maskformer-swin-large-ade'
# Seed every RNG the notebook relies on (random.sample previews, new head / LoRA init,
# DataLoader shuffling) so Restart & Run All is reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Image and mask files are paired purely by position after sorting, so keep only TIFFs
# and check that every pair shares the same trailing index ("..._<i>.tif").
train_true = sorted(f for f in os.listdir('train_true_color') if f.lower().endswith(('.tif', '.tiff')))
train_mask = sorted(f for f in os.listdir('train_mask') if f.lower().endswith(('.tif', '.tiff')))
assert len(train_true) == len(train_mask), 'image/mask counts differ'
assert all(t.rsplit('_', 1)[-1] == m.rsplit('_', 1)[-1] for t, m in zip(train_true, train_mask)), \
    'image/mask files are not paired by index'

In [None]:
print('Data length: {}'.format(len(train_true)))  # number of training scenes

In [None]:
# Preview six random training scenes; `items` is reused by the next two cells.
# NOTE(review): assumes files are named train_true_color_<i>.tif for i in 0..N-1 — confirm.
items = random.sample(range(len(train_true)), 6)
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
for item, ax in zip(items, axes.flat):
  with rasterio.open(f"train_true_color/train_true_color_{item}.tif") as src:
    image = src.read()
  # Guard against an all-zero scene (0/0 -> NaN), as the mask cells already do.
  peak = np.max(image)
  image = image / (peak if peak != 0 else 1)
  # .T turns (bands, rows, cols) into (cols, rows, bands): the scene is shown transposed,
  # consistently with the mask cells' .T so that overlays line up.
  ax.imshow(image.T)
  ax.set_title(f'image index: {item}')
fig.tight_layout()
plt.show()

In [None]:
# Masks for the same six scenes. Non-zero pixels are drawn dark via logical_not + 'binary'.
fig = plt.figure(figsize=(15, 10))
for slot, item in enumerate(items, start=1):
  ax = fig.add_subplot(2, 3, slot)
  with rasterio.open(f"train_mask/train_mask_{item}.tif") as src:
    raw = src.read()
  peak = np.max(raw)
  scaled = raw / (peak if peak != 0 else 1)
  ax.imshow(np.logical_not(scaled.reshape(1000, 1000).T), cmap='binary')
  ax.set_title(f'image index: {item}')
fig.tight_layout()
plt.show()

In [None]:
# Overlay each mask (gray, 30% opacity) on its RGB scene for the same six samples.
fig = plt.figure(figsize=(15, 10))
for slot, item in enumerate(items, start=1):
  ax = fig.add_subplot(2, 3, slot)
  with rasterio.open(f"train_true_color/train_true_color_{item}.tif") as src:
    rgb = src.read()
  ax.imshow((rgb / np.max(rgb)).T)
  ax.set_title(f'image index: {item}')
  with rasterio.open(f"train_mask/train_mask_{item}.tif") as src:
    raw = src.read()
  peak = np.max(raw)
  scaled = raw / (peak if peak != 0 else 1)
  ax.imshow(np.logical_not(scaled.reshape(1000, 1000).T), cmap='gray', alpha=0.3)
fig.tight_layout()
plt.show()

In [None]:
# Image processor matching the checkpoint. Label reduction is disabled explicitly.
# With it on, mask value 0 would become ignore_index and 1 would shift to 0. That would
# break the 0='not_clouds' / 1='clouds' mapping (assuming masks hold 0/1).
# NOTE(review): the ADE checkpoint's processor config may enable reduction by default — confirm.
preprocessor = MaskFormerImageProcessor.from_pretrained(model_name, do_reduce_labels=False)

In [None]:
class CloudsDataset(Dataset):
  """Paired (image, mask) dataset over the 1000x1000 training GeoTIFFs.

  With ``splitting=True`` each scene yields four 500x500 quadrant samples, so the
  dataset is four times as long as the file list. Otherwise each scene is one sample.

  Args:
    image_files: image file names inside ``image_dir``.
    mask_files: mask file names inside ``mask_dir``, paired by position with image_files.
    splitting: whether to tile each scene into four quadrants.
    image_dir: directory holding the images (default matches the original layout).
    mask_dir: directory holding the masks.

  Returns (per item):
    image: float32 tensor (3, H, W) scaled by its own maximum (left as-is if all zero).
    mask: float32 tensor (H, W).
  """

  def __init__(self, image_files, mask_files, splitting,
               image_dir='train_true_color', mask_dir='train_mask'):
    self.image_files = image_files
    self.mask_files = mask_files
    self.splitting = splitting
    self.image_dir = image_dir
    self.mask_dir = mask_dir

  def __len__(self):
    return len(self.image_files) * 4 if self.splitting else len(self.image_files)

  def _read(self, path):
    """Read one TIFF as float32 (kept separate so the reader can be swapped, e.g. in tests)."""
    return tifffile.imread(path).astype(np.float32)

  def __getitem__(self, idx):
    # Only the tiled dataset packs four samples per file; otherwise idx is the file index.
    if self.splitting:
      file_idx, sub_idx = divmod(idx, 4)
    else:
      file_idx, sub_idx = idx, None

    image = torch.from_numpy(self._read(os.path.join(self.image_dir, self.image_files[file_idx])))
    mask = torch.from_numpy(self._read(os.path.join(self.mask_dir, self.mask_files[file_idx])))

    # Normalize by the maximum; an all-zero scene stays zeros instead of dividing by 0.
    peak = torch.max(image)
    if peak != 0:
      image = image / peak
    # NOTE(review): .view assumes the TIFF is stored band-first (3, 1000, 1000). If tifffile
    # returns (1000, 1000, 3), this scrambles channels; permute instead. TODO confirm.
    image = image.view((3, 1000, 1000))
    mask = mask.view((1000, 1000))

    if sub_idx is not None:
      # Quadrants 0..3: (top, left), (bottom, left), (top, right), (bottom, right).
      half = 500
      row = (sub_idx % 2) * half
      col = (sub_idx // 2) * half
      image = image[:, row:row + half, col:col + half]
      mask = mask[row:row + half, col:col + half]

    return image, mask

In [None]:
def collate_fn(batch):
  """Turn a list of (image, mask) pairs into MaskFormer model inputs.

  The dataset already scales and crops the tensors, so the processor's rescale,
  normalize and resize steps are disabled. The untouched input tensors are kept
  under 'original_image' / 'original_mask' for evaluation.

  NOTE(review): do_normalize=False means the Swin backbone never sees the
  mean/std normalization it was pretrained with — confirm this is intended.
  """
  images, segmentation_maps = zip(*batch)
  encoded = preprocessor(
      images,
      segmentation_maps=segmentation_maps,
      return_tensors="pt",
      do_rescale=False,
      do_normalize=False,
      do_resize=False,
      ignore_index=-1,
  )
  encoded['original_image'] = images
  encoded['original_mask'] = segmentation_maps
  return encoded

In [None]:
# 80/20 split by position in the sorted file list; two samples per optimisation step.
batch_size = 2
train_split_ratio = 0.8
eval_split_ratio = 1 - train_split_ratio
train_size = int(len(train_true) * train_split_ratio)

In [None]:
# Whole, untiled scenes: used only for the mask statistics below.
tmp_dataset = CloudsDataset(train_true, train_mask, False)
# Tiled (4 quadrants per scene) train / eval datasets.
train_dataset = CloudsDataset(train_true[:train_size], train_mask[:train_size], True)
eval_dataset = CloudsDataset(train_true[train_size:], train_mask[train_size:], True)
tmp_dataloader = DataLoader(tmp_dataset, batch_size=1, drop_last=True)
# Shuffle training samples. Unshuffled, consecutive batches are quadrants of the same
# scene, which makes gradient steps highly correlated.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn)
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, drop_last=True, collate_fn=collate_fn)

In [None]:
# Distribution of the per-scene mean mask value over the whole training set (batch_size=1,
# untiled). Assuming masks are 0/1, this is the cloud fraction of each scene — TODO confirm.
ratios = [torch.mean(mask).item() for _, mask in tqdm(tmp_dataloader, desc='mask stats')]
ax = sns.histplot(ratios, element='poly')
ax.set(title='Per-scene mean mask value',
       xlabel='mean mask value (cloud fraction if masks are 0/1)',
       ylabel='number of scenes')
plt.show()

In [None]:
# Class ids are list positions: 0 -> 'not_clouds', 1 -> 'clouds'.
labels = ['not_clouds', 'clouds']
id2label = dict(enumerate(labels))

In [None]:
# Release cached, unused GPU memory before loading the large checkpoint.
torch.cuda.empty_cache()
# Replace the checkpoint's classification head with a len(id2label)-class head.
# ignore_mismatched_sizes re-initializes the head weights whose shapes no longer match.
# label2id is passed too; otherwise the config keeps the checkpoint's original label2id.
model = MaskFormerForInstanceSegmentation.from_pretrained(
    model_name,
    id2label=id2label,
    label2id={label: idx for idx, label in id2label.items()},
    ignore_mismatched_sizes=True,
)

In [None]:
# LoRA adapters (rank 16, alpha 32, dropout 0.1) on modules named query / value / key.
# NOTE(review): modules unfrozen manually in the next cell are not in modules_to_save.
# PEFT's save_pretrained may therefore not persist them — confirm before saving adapters.
config = LoraConfig(
    target_modules=["query", "value", "key"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
)
model = get_peft_model(model, config)
model = model.to(device)

In [None]:
# Despite the name, these are the modules made fully TRAINABLE on top of LoRA: every
# parameter whose name contains one of these substrings gets requires_grad=True.
freezed_layers = ['decoder.layers.5', 'class_predictor', 'mask_embedder']
params_to_unfreeze = [
    param for name, param in model.named_parameters()
    if any(key in name for key in freezed_layers)
]
for param in params_to_unfreeze:
  param.requires_grad_(True)

In [None]:
# Share of parameters (by element count) that will receive gradient updates.
trainable_parameters = 0
total_parameters = 0
for p in model.parameters():
  count = p.numel()
  total_parameters += count
  if p.requires_grad:
    trainable_parameters += count
ratio = (trainable_parameters / total_parameters) * 100
print('Percent of trainable parameters: ', ratio)

In [None]:
# Smoke test: one forward pass on a training batch to inspect the initial loss.
# no_grad avoids building (and holding) an autograd graph that is never used.
batch = next(iter(train_dataloader))
with torch.no_grad():
  outputs = model(
      pixel_values=batch["pixel_values"].to(device),
      mask_labels=[labels.to(device) for labels in batch["mask_labels"]],
      class_labels=[labels.to(device) for labels in batch["class_labels"]],
  )
print('init loss:', outputs.loss)

In [None]:
# Layer summary for a batch of 500x500 RGB tiles (the quadrant size produced by the dataset).
torchinfo.summary(model, input_size=[batch_size, 3, 500, 500], device=device)

In [None]:
num_epochs = 10
# Mean-IoU metric, accumulated per batch during validation.
metric_a = evaluate.load('mean_iou')

In [None]:
# Preview the schedule actually used for training: same peak lr, power, and step count.
# A throwaway parameter keeps the real model's parameters out of this optimizer. Calling
# optimizer.step() before scheduler.step() avoids PyTorch's ordering warning.
peak_lr = 0.0001
total_steps = num_epochs * len(train_dataloader)
preview_optimizer = torch.optim.AdamW([torch.zeros(1, requires_grad=True)], lr=peak_lr)
preview_scheduler = torch.optim.lr_scheduler.PolynomialLR(preview_optimizer, total_iters=total_steps, power=1.3)
learning_rates = []
for _ in range(total_steps):
    learning_rates.append(preview_optimizer.param_groups[0]['lr'])
    preview_optimizer.step()  # no-op: the dummy parameter has no gradient
    preview_scheduler.step()
plt.figure(figsize=(10, 6))
plt.plot(learning_rates, label='Learning Rate', linewidth=2)
plt.xlabel('Step')
plt.ylabel('Learning Rate')
plt.title('PolynomialLR')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# Optimize only the trainable parameters (LoRA adapters + manually unfrozen modules).
# Frozen ones never get a .grad, so AdamW would skip them anyway.
optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=0.0001)
# Decay to 0 exactly at the final training step. With a hard-coded total_iters smaller
# than the real step count, the lr would sit at 0 for the rest of training.
scheduler = torch.optim.lr_scheduler.PolynomialLR(
    optimizer, total_iters=num_epochs * len(train_dataloader), power=1.3
)

In [None]:
def _safe_div(num, den):
  """Return num / den, or 0.0 when den is 0 (e.g. no positive predictions or labels)."""
  return num / den if den else 0.0


for epoch in range(num_epochs):
  print('Running epoch:', epoch+1)
  # from_pretrained() returns the model in eval mode; re-enable dropout (incl. LoRA dropout).
  model.train()
  running_loss = 0
  for idx, batch in enumerate(tqdm(train_dataloader)):

    torch.cuda.empty_cache()
    optimizer.zero_grad()

    outputs = model(
        pixel_values=batch["pixel_values"].to(device),
        mask_labels=[labels.to(device) for labels in batch["mask_labels"]],
        class_labels=[labels.to(device) for labels in batch["class_labels"]],
    )

    loss = outputs.loss
    loss.backward()
    running_loss += loss.item()

    if (idx+1) % 50 == 0:
      print(f'Loss at batch {idx+1}: {running_loss/(idx+1)}')

    optimizer.step()
    scheduler.step()

  print(f'Train loss at {epoch+1}:', running_loss/len(train_dataloader))

  # Validation: dropout off. Pixel-level confusion counts treat class 1 ('clouds') as positive.
  model.eval()
  tp, fp, tn, fn = 0, 0, 0, 0
  val_loss = 0.0
  for idx, batch in enumerate(tqdm(eval_dataloader)):
    torch.cuda.empty_cache()
    with torch.no_grad():
      outputs = model(
          pixel_values=batch["pixel_values"].to(device),
          mask_labels=[labels.to(device) for labels in batch["mask_labels"]],
          class_labels=[labels.to(device) for labels in batch["class_labels"]],
      )
    val_loss += outputs.loss.item()

    original_images = batch['original_image']
    target_sizes = [(image.shape[1], image.shape[2]) for image in original_images]
    predicted_segmentation_maps = preprocessor.post_process_semantic_segmentation(outputs,
                                                                                  target_sizes=target_sizes)
    # Bring predictions and references to CPU int64 so they share device and dtype.
    predictions = [p.cpu().long() for p in predicted_segmentation_maps]
    references = [m.cpu().long() for m in batch["original_mask"]]
    metric_a.add_batch(references=[r.numpy() for r in references],
                       predictions=[p.numpy() for p in predictions])

    pred = torch.cat([p.reshape(-1) for p in predictions])
    gt = torch.cat([r.reshape(-1) for r in references])
    # Count every cell explicitly. A single-class batch (e.g. all clouds) then still lands
    # in the right cell, instead of always being counted as true negatives.
    tp += int(((pred == 1) & (gt == 1)).sum())
    fp += int(((pred == 1) & (gt == 0)).sum())
    tn += int(((pred == 0) & (gt == 0)).sum())
    fn += int(((pred == 0) & (gt == 1)).sum())

  precision = _safe_div(tp, tp + fp)
  recall = _safe_div(tp, tp + fn)
  f1 = _safe_div(2 * precision * recall, precision + recall)
  print("IoU:", metric_a.compute(num_labels = len(id2label), ignore_index = -1))
  print('Precision:', precision)
  print('Recall:', recall)
  print('F1 Score:', f1)
  print('Val Loss:', val_loss / len(eval_dataloader))
  print('-----------------------------------------------------')