In [None]:
# Third-party packages imported further below (pytorch_lightning, einops, timm's DropPath).
# Output is discarded. NOTE(review): versions are not pinned -- consider pinning for reproducibility.
!pip install pytorch-lightning > /dev/null 2>&1
!pip install einops > /dev/null 2>&1
!pip install timm > /dev/null 2>&1

In [None]:
# Fetch a fresh copy of the MixformerFromScratch repository and install it in editable
# mode, so that `from Mixformer import ...` in the imports cell can resolve.
!rm -rf MixformerFromScratch
!git clone https://github.com/reeWorlds/MixformerFromScratch.git
!pip install -e "MixformerFromScratch"

# Re-run site initialisation in the already-running kernel -- presumably so the path
# entry added by the editable install is picked up without restarting the kernel.
import site
site.main()

In [None]:
# Manual kill switch: flip False -> True to terminate the kernel process immediately
# (os._exit skips all cleanup). Kept disabled so "Run All" does not kill the kernel.
if False:
  import os
  os._exit(0)

In [None]:
import torch
import pytorch_lightning as pl
import numpy as np
import os
import gc

import matplotlib.pyplot as plt

import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import DropPath

from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint

from Mixformer import st1_target
from Mixformer import st1_search

In [None]:
from google.colab import drive
drive.mount('/content/drive')
# Root folder of the Stage-1 search-part dataset on Google Drive; trained checkpoints
# of this notebook are also written under `<data_prefix>/models/`.
data_prefix = '/content/drive/My Drive/Data/DiplomeGenerated/Stage1_SimpleSearchPart'

In [None]:
# Folder that holds the raw binary dataset patches read by get_data_by_num.
data_folder_path = data_prefix

# Patches 0..9 are concatenated into the training set; patch 10 is held out for validation.
train_patches_nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
valid_path_num = 10  # "path" here means the patch number

# Placeholders; both triples are (re)assigned by the loading code below.
train_search, train_class, train_out = None, None, None
valid_search, valid_class, valid_out = None, None, None

def get_tensor_by_path(file_path, size, shape, dtype):
  """Expose a raw binary file as a tensor without reading it eagerly.

  The file is memory-mapped read-only as a flat array of `size` elements of
  `dtype`. That array is wrapped as a tensor sharing the mapped buffer and then
  reshaped to `shape`, so prod(shape) must equal `size`.

  Args:
    file_path: path of the raw binary file.
    size: total number of elements to map.
    shape: target tensor shape.
    dtype: numpy dtype of the stored elements.

  Returns:
    torch.Tensor with the requested shape, backed by the memory map.
  """
  flat_view = np.memmap(file_path, dtype=dtype, mode='r', shape=(size,))
  return torch.from_numpy(flat_view).reshape(*shape)

def get_data_by_num(path_num, n_samples=10000, image_size=64, folder=None):
  """Load one dataset patch: search images, class ids and target masks.

  A patch consists of three raw binary files in `folder`:
    patch{N}_search.bin : float32, (n_samples, image_size, image_size, 3)
    patch{N}_class.bin  : uint8,   (n_samples,)
    patch{N}_output.bin : uint8,   (n_samples, image_size, image_size)

  Args:
    path_num: patch number N used in the file names.
    n_samples: number of samples stored in the patch (previously hard-coded 10000).
    image_size: side length of the square images/masks (previously hard-coded 64).
    folder: directory with the patch files; None -> notebook-level `data_folder_path`.

  Returns:
    (search_tensor, class_tensor, out_tensor):
      search_tensor: float32 (n_samples, image_size, image_size, 3), backed by a read-only memmap.
      class_tensor:  int32 (n_samples,).
      out_tensor:    float32 (n_samples, image_size, image_size), bytes scaled by 1/255 into [0, 1].
  """
  if folder is None:
    folder = data_folder_path

  def _patch_file(kind):
    # e.g. patch3_search.bin / patch3_class.bin / patch3_output.bin
    return os.path.join(folder, f'patch{path_num}_{kind}.bin')

  search_shape = (n_samples, image_size, image_size, 3)
  search_tensor = get_tensor_by_path(_patch_file('search'), int(np.prod(search_shape)),
                                     search_shape, np.float32)

  class_tensor = get_tensor_by_path(_patch_file('class'), n_samples, (n_samples,), np.uint8)
  class_tensor = class_tensor.int()  # uint8 -> int32 (materialised copy)

  out_shape = (n_samples, image_size, image_size)
  out_tensor = get_tensor_by_path(_patch_file('output'), int(np.prod(out_shape)),
                                  out_shape, np.uint8)
  out_tensor = out_tensor.float() / 255.0  # 0..255 bytes -> [0, 1] floats

  return search_tensor, class_tensor, out_tensor

list_s, list_c, list_o = [], [], []

for patch_num in train_patches_nums:
  s, c, o = get_data_by_num(patch_num)
  list_s.append(s)
  list_c.append(c)
  list_o.append(o)
  print(f'Finished patch_num = {patch_num}')

# torch.cat allocates new tensors that hold every training patch.
train_search = torch.cat(list_s, dim=0)
train_class = torch.cat(list_c, dim=0)
train_out = torch.cat(list_o, dim=0)

# The per-patch tensors are now duplicated inside the concatenated ones. Drop every
# reference to them (the lists and the loop variables) so gc.collect() can actually
# release that memory -- without this `del`, the lists kept all pieces alive.
del list_s, list_c, list_o, s, c, o

valid_search, valid_class, valid_out = get_data_by_num(valid_path_num)
gc.collect()

print(f'train data shapes are s:{train_search.shape}, c:{train_class.shape}, o:{train_out.shape}')
print(f'valid data shapes are s:{valid_search.shape}, c:{valid_class.shape}, o:{valid_out.shape}')

In [None]:
def plot_image(searches, outs, index, figsize=(4, 4)):
  """Show one search image next to its ground-truth mask.

  Args:
    searches: indexable of image tensors; searches[index] is passed to imshow
      (an (H, W, 3) tensor when it comes from get_data_by_num).
    outs: indexable of mask tensors, drawn in grayscale on a fixed [0, 1] scale.
    index: sample index to display.
    figsize: matplotlib figure size in inches (default unchanged from before).
  """
  # No plt.clf() here: plt.subplots() already creates a fresh figure, and clearing the
  # implicit "current" figure only produced an extra empty figure in the cell output.
  fig, (ax_search, ax_mask) = plt.subplots(1, 2, figsize=figsize)
  ax_search.imshow(searches[index].numpy())
  ax_search.set_title('Search Image')
  ax_mask.imshow(outs[index].numpy(), cmap='gray', vmin=0, vmax=1)
  ax_mask.set_title('Mask Image')
  plt.show()

In [None]:
# Human-readable names for the integer class ids stored in the dataset.
class_ind_to_name = dict(enumerate(['Water', 'Sand', 'Grass', 'Mountain', 'Snow']))

ind = 24  # training sample to inspect
print(class_ind_to_name[int(train_class[ind])])
plot_image(train_search, train_out, ind)

In [None]:
class BaseModel(pl.LightningModule):
    """Lightning wrapper around the stage-1 *target* MixFormer ("medium" config).

    In this notebook it is only used to restore a pretrained network from a
    checkpoint; the wrapped network is exposed as `self.model`.
    """

    def __init__(self):
        super().__init__()
        self.model = st1_target.MixFormer(st1_target.make_mixformer_config("medium"))

    def forward(self, _search, _class):
        """Delegate directly to the wrapped MixFormer."""
        return self.model(_search, _class)

In [None]:
class MyDataset(torch.utils.data.Dataset):
    """Index-based view over parallel (search, class, out) tensors.

    When `_class_type` is given, only samples whose class id equals it are
    exposed; otherwise every sample is. Items are (search, class, out) triples.
    """

    def __init__(self, _search, _class, _out, _class_type=None):
        self._search, self._class, self._out = _search, _class, _out
        # Maps dataset position -> row index in the underlying tensors.
        if _class_type is not None:
            self._list_ind = torch.nonzero(_class == _class_type, as_tuple=True)[0].tolist()
        else:
            self._list_ind = list(range(len(_search)))

    def __len__(self):
        return len(self._list_ind)

    def __getitem__(self, idx):
        row = self._list_ind[idx]
        return self._search[row], self._class[row], self._out[row]

In [None]:
class LightningMixFormer(pl.LightningModule):
    """Lightning module that trains the stage-1 *search* MixFormer.

    The wrapped network (`self.model`) is built from the "medium" st1_search config
    and receives an optional `base_model` (in this notebook, the pretrained target
    MixFormer restored from checkpoint).

    Training settings are plain attributes so the notebook can retune them between
    `trainer.fit` calls:
      start_lr, lr_gamma   : AdamW learning rate, multiplied by lr_gamma every epoch (StepLR).
      weight_decay         : AdamW weight decay.
      train_batch_size, valid_batch_size, num_workers : DataLoader settings.
      _class_type          : if not None, both datasets keep only samples of that class id.

    NOTE: the dataloaders read the notebook-level tensors train_search/train_class/
    train_out and valid_search/valid_class/valid_out, which must be loaded first.
    """

    def __init__(self, base_model=None):
        super().__init__()
        config = st1_search.make_mixformer_config("medium")
        self.model = st1_search.MixFormer(config, base_model)
        self._class_type = None
        self.start_lr = 1e-3
        self.lr_gamma = 0.75
        # Previously hard-coded in configure_optimizers / the dataloaders; same defaults.
        self.weight_decay = 1e-6
        self.train_batch_size = 256
        self.valid_batch_size = 512
        self.num_workers = 2

    def forward(self, _search, _class):
        return self.model(_search, _class)

    def get_loss(self, out_pred, out):
        """Mean squared error between predicted and target masks."""
        # Alternative kept from the original code: F.binary_cross_entropy(out_pred, out)
        return F.mse_loss(out_pred, out)

    def _shared_step(self, batch):
        """Run the model on one (search, class, out) batch and return the loss."""
        _search, _class, _out = batch
        out_pred = self.model(_search, _class)
        return self.get_loss(out_pred, _out)

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        """AdamW with a per-epoch StepLR decay; reads the current attribute values."""
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.start_lr,
                                      weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.lr_gamma)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch',
                'frequency': 1
            }
        }

    def train_dataloader(self):
        train_dataset = MyDataset(train_search, train_class, train_out, self._class_type)
        return torch.utils.data.DataLoader(train_dataset, batch_size=self.train_batch_size,
                                           shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        valid_dataset = MyDataset(valid_search, valid_class, valid_out, self._class_type)
        return torch.utils.data.DataLoader(valid_dataset, batch_size=self.valid_batch_size,
                                           shuffle=False, num_workers=self.num_workers)

In [None]:
def get_trainer(max_epochs, checkpoint_dir='my_model/', log_dir='logs'):
  """Build a Lightning Trainer that keeps the 5 best checkpoints by val_loss.

  Args:
    max_epochs: number of epochs for trainer.fit.
    checkpoint_dir: directory where ModelCheckpoint writes .ckpt files.
    log_dir: root directory of the CSV metrics logger.

  Returns:
    A configured pl.Trainer.
  """
  checkpoint_callback = ModelCheckpoint(
      monitor='val_loss', dirpath=checkpoint_dir,
      # 4 decimals: val_loss is an MSE against [0, 1] masks and is likely small, so
      # with 2 decimals successive checkpoints ended up with indistinguishable names.
      filename='model-{epoch:02d}-{val_loss:.4f}',
      save_top_k=5, mode='min')
  csv_logger = pl_loggers.CSVLogger(log_dir)
  return pl.Trainer(max_epochs=max_epochs, callbacks=[checkpoint_callback],
                    logger=csv_logger)

In [None]:
# Pretrained stage-1 *target* network (checkpoint version `base_model_v`); its inner
# `.model` is handed to the search model below.
base_model_path_prefix = '/content/drive/My Drive/Data/DiplomeGenerated/Stage1_SimpleTargetPart/models'
base_model_v = 1
base_model_path = os.path.join(base_model_path_prefix, f'model_medium_v{base_model_v}.ckpt')
base_model = BaseModel.load_from_checkpoint(checkpoint_path=base_model_path)

In [None]:
# Search-stage model built on top of the pretrained target network (base_model.model).
model = LightningMixFormer(base_model.model)

In [None]:
# Per-stage training settings (previously two copy-pasted if/elif branches).
# Stage 1 calls set_base_requires_grad(False), presumably freezing the pretrained base.
# Stage 2 re-enables it and trains longer.
STAGE_CONFIGS = {
  1: dict(max_epochs=3, start_lr=5e-4, lr_gamma=0.8, class_type=None, train_base=False),
  2: dict(max_epochs=20, start_lr=1e-3, lr_gamma=0.865, class_type=None, train_base=True),
}
stages = [1, 2]

for stage in stages:
  cfg = STAGE_CONFIGS[stage]  # an unknown stage id now fails loudly instead of being skipped
  trainer = get_trainer(cfg['max_epochs'])
  model.start_lr = cfg['start_lr']
  model.lr_gamma = cfg['lr_gamma']
  model._class_type = cfg['class_type']
  model.model.set_base_requires_grad(cfg['train_base'])
  trainer.fit(model)

In [None]:
import shutil

model_v = 1

# Optional: archive the run's CSV metrics next to the checkpoint (disabled).
# source_path = '/content/logs/lightning_logs/version_1/metrics.csv'
# dest_path = os.path.join(data_prefix, f'models/logs_medium_v{model_v}.csv')
# shutil.copyfile(source_path, dest_path)

# Save the final trainer state twice: a local copy, then a persistent copy on Drive.
model_checkpoint_path = os.path.join(data_prefix, f'models/model_medium_v{model_v}.ckpt')
for ckpt_target in ("model.ckpt", model_checkpoint_path):
  trainer.save_checkpoint(ckpt_target)

In [None]:
checkpoint_path = os.path.join(data_prefix, f'models/model_medium_v{model_v}.ckpt')
# Use the GPU when present, but fall back to CPU instead of crashing on CPU-only runtimes.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = LightningMixFormer.load_from_checkpoint(checkpoint_path=checkpoint_path)
model.eval()
model.to(device)
pass  # keeps the cell from echoing the module repr returned by .to()

In [None]:
def get_outputs(searches, classes, outs, net=None, batch_size=512, num_workers=2):
  """Run a model over a whole dataset and return its clamped predictions on CPU.

  Args:
    searches, classes, outs: parallel tensors wrapped in MyDataset; `outs` is only
      carried by the dataset and is not used to compute predictions.
    net: module to evaluate; defaults to the notebook-level `model`. It is expected
      to already be in eval mode.
    batch_size, num_workers: DataLoader settings (defaults unchanged).

  Returns:
    Predictions for all samples concatenated along dim 0, clamped to [0, 1].
    torch.cat raises if the dataset is empty.
  """
  if net is None:
    net = model
  # Follow the device of the model's weights instead of assuming CUDA.
  device = next(net.parameters()).device
  dataset = MyDataset(searches, classes, outs)
  data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                            num_workers=num_workers)
  all_outputs = []
  with torch.no_grad():
    for _search, _class, _ in data_loader:
      outputs = net(_search.to(device), _class.to(device))
      all_outputs.append(torch.clamp(outputs, min=0, max=1).cpu())
  return torch.cat(all_outputs, dim=0)

valid_model_outs = get_outputs(valid_search, valid_class, valid_out)
print(valid_model_outs.shape)

In [None]:
def plot_image2(searches, outs, outs_pred, index, figsize=(9, 9)):
  """Show a search image, its ground-truth mask and the predicted mask side by side.

  Args:
    searches: indexable of image tensors; searches[index] is passed to imshow
      (an (H, W, 3) tensor when it comes from get_data_by_num).
    outs: indexable of ground-truth mask tensors (grayscale, fixed [0, 1] scale).
    outs_pred: indexable of predicted mask tensors (same display scale).
    index: sample index to display.
    figsize: matplotlib figure size in inches (default unchanged from before).
  """
  # No plt.clf(): plt.subplots() creates its own figure, and clearing the implicit
  # "current" figure only produced an extra empty figure in the cell output.
  fig, axes = plt.subplots(1, 3, figsize=figsize)
  mask_style = dict(cmap='gray', vmin=0, vmax=1)
  panels = [
      (searches[index], {}, 'Search Image'),
      (outs[index], mask_style, 'Mask Image'),
      (outs_pred[index], mask_style, 'Predicted Mask Image'),
  ]
  for ax, (img, imshow_kwargs, title) in zip(axes, panels):
    ax.imshow(img.numpy(), **imshow_kwargs)
    ax.set_title(title)
  plt.show()

In [None]:
# Human-readable names for the integer class ids stored in the dataset.
class_ind_to_name = dict(enumerate(['Water', 'Sand', 'Grass', 'Mountain', 'Snow']))

ind = 7  # validation sample to inspect; other indices tried: 0, 2, 3, 7, 8
print(class_ind_to_name[int(valid_class[ind])])
plot_image2(valid_search, valid_out, valid_model_outs, ind)

In [None]:
import shutil

# Best-effort removal of local training artefacts (CSV logs and ModelCheckpoint files).
# ignore_errors=True covers "directory does not exist" without the bare `except:`,
# which also swallowed KeyboardInterrupt/SystemExit.
CLEANUP_TRAINING_ARTIFACTS = True

if CLEANUP_TRAINING_ARTIFACTS:
  for artifact_dir in ("/content/logs", "/content/my_model"):
    shutil.rmtree(artifact_dir, ignore_errors=True)