In [None]:
!pip install pytorch-lightning > /dev/null 2>&1
!pip install einops > /dev/null 2>&1
!pip install timm > /dev/null 2>&1

In [None]:
# Start from a clean checkout of the project and install it in editable mode.
!rm -rf MixformerFromScratch
!git clone https://github.com/reeWorlds/MixformerFromScratch.git
!pip install -e "MixformerFromScratch"

# NOTE(review): presumably re-runs site initialisation so the editable install
# added above becomes importable without restarting the kernel — confirm.
import site
site.main()

In [None]:
# Manual kill switch: change the condition to True to terminate the interpreter
# process immediately through os._exit(0). Disabled by default.
if False:
  import os
  os._exit(0)

In [None]:
import torch
import pytorch_lightning as pl
import numpy as np
import os
import gc

import matplotlib.pyplot as plt

import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import math
import seaborn as sns
from functools import reduce
from operator import mul
from einops import rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import DropPath

from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint

from Mixformer import st2_xformer
from Mixformer import st3_mxformer

In [None]:
# Mount Google Drive into the Colab runtime and point data_prefix at the
# Stage3 data folder stored on it.
from google.colab import drive
drive.mount('/content/drive')
data_prefix = '/content/drive/My Drive/Data/DiplomeGenerated/Stage3'

In [None]:
# Folder that holds the patch*_search/target/mask .bin files.
data_folder_path = data_prefix

# Patch indices used for training (0..9) and the patch used for validation.
train_patches_nums = [*range(10)]
valid_pathch_num = 10

# Placeholders; they are assigned once the patches have been loaded.
trn_s = trn_t = trn_m = None
vld_s = vld_t = vld_m = None

def get_tensor_by_path(file_path, size, shape, dtype):
  """Memory-map a flat binary file and expose it as a reshaped torch tensor.

  Args:
    file_path: path of the raw binary file.
    size: number of `dtype` elements mapped from the start of the file.
    shape: target shape; its element count must equal `size`.
    dtype: numpy dtype of the stored elements.

  Returns:
    A tensor of shape `shape` backed by the memory mapping.
  """
  # Copy-on-write mapping: the file on disk is never modified, but the array
  # is writable. A read-only ('r') mapping makes torch.from_numpy warn that
  # the array is not writable and leaves writes through the tensor undefined.
  mmapped_array = np.memmap(file_path, dtype=dtype, mode='c', shape=(size,))
  tensor = torch.from_numpy(mmapped_array)
  return tensor.reshape(*shape)

def get_data_by_num(path_num):
  """Load the search, target and mask tensors of one patch.

  Args:
    path_num: index used in the `patch{path_num}_*.bin` file names inside
      data_folder_path.

  Returns:
    Tuple (search, target, mask) of float32 tensors with shapes
    (10000, 64, 64, 3), (10000, 48, 48, 3) and (10000, 64, 64).
  """
  # (file-name suffix, tensor shape) of each stored array, in return order.
  layout = (
    ('search', (10000, 64, 64, 3)),
    ('target', (10000, 48, 48, 3)),
    ('mask', (10000, 64, 64)),
  )
  loaded = []
  for suffix, dims in layout:
    bin_path = os.path.join(data_folder_path, f'patch{path_num}_{suffix}.bin')
    loaded.append(get_tensor_by_path(bin_path, reduce(mul, dims), dims, np.float32))
  s, t, m = loaded
  return s, t, m

# Load every training patch, then stack the patches along the sample axis.
list_s, list_t, list_m = [], [], []

for patch_num in train_patches_nums:
  s, t, m = get_data_by_num(patch_num)
  list_s.append(s)
  list_t.append(t)
  list_m.append(m)
  # Progress message for every second patch.
  if patch_num % 2 == 0:
    print(f'Finished patch_num = {patch_num}')

trn_s = torch.cat(list_s, dim=0)
trn_t = torch.cat(list_t, dim=0)
trn_m = torch.cat(list_m, dim=0)

# The validation set is a single patch, so no concatenation is needed.
vld_s, vld_t, vld_m = get_data_by_num(valid_pathch_num)

gc.collect()

print(f'train data shapes are s:{trn_s.shape} t:{trn_t.shape} m:{trn_m.shape}')
# Label corrected: this line reports the validation tensors, not the training ones.
print(f'valid data shapes are s:{vld_s.shape} t:{vld_t.shape} m:{vld_m.shape}')

In [None]:
def plot_image(img_s, img_t, img_msk, idx):
  """Show the search image, target image and mask of one sample side by side.

  Args:
    img_s: indexable batch of search images; `.numpy()` is called on the item.
    img_t: indexable batch of target images; `.numpy()` is called on the item.
    img_msk: indexable batch of masks, drawn in grayscale on a fixed 0..1 scale.
    idx: index of the sample to display.
  """
  # No plt.clf() here: when no figure is open it creates an extra empty figure
  # that gets shown next to the real one, and plt.subplots() always starts a
  # fresh figure anyway.
  panels = (
    ('Search', img_s[idx].numpy(), {}),
    ('Target', img_t[idx].numpy(), {}),
    ('Mask', img_msk[idx].numpy(), {'cmap': 'gray', 'vmin': 0, 'vmax': 1}),
  )
  fig, ax = plt.subplots(1, 3, figsize=(7, 3))
  for axis, (title, image, imshow_kwargs) in zip(ax, panels):
    axis.imshow(image, **imshow_kwargs)
    axis.set_title(title)
  plt.show()

In [None]:
# Preview one training sample (search image, target image and mask).
idx = 0
plot_image(trn_s, trn_t, trn_m, idx)

In [None]:
class MyDataset(torch.utils.data.Dataset):
    """Dataset yielding aligned (search, target, mask) triples.

    Args:
        _s: indexable collection of search samples.
        _t: indexable collection of target samples.
        _m: indexable collection of mask samples.

    Raises:
        ValueError: if the three collections do not have the same length.
    """

    def __init__(self, _s, _t, _m):
        # Fail early on misaligned inputs; otherwise the mismatch only shows up
        # as an IndexError mid-epoch or as silently ignored trailing samples.
        if not (len(_s) == len(_t) == len(_m)):
            raise ValueError(
                f'search/target/mask lengths differ: {len(_s)}, {len(_t)}, {len(_m)}')
        self._s = _s
        self._t = _t
        self._m = _m

    def __len__(self):
        """Number of samples (length of the search collection)."""
        return len(self._s)

    def __getitem__(self, idx):
        """Return the (search, target, mask) triple stored at position idx."""
        return self._s[idx], self._t[idx], self._m[idx]

In [None]:
# Model size preset: passed to the make_*_config helpers and embedded in the
# checkpoint file names (model_{size_str}.ckpt).
size_str = 'large'

In [None]:
class LightningBaseModel(pl.LightningModule):
  """Lightning wrapper around st2_xformer.Transformer."""

  def __init__(self):
    super().__init__()
    # The configuration is selected by the module-level size_str preset.
    self.model = st2_xformer.Transformer(st2_xformer.make_transformer_config(size_str))

  def forward(self, _d):
    """Run the wrapped transformer on _d and return its output."""
    output = self.model(_d)
    return output

In [None]:
class LightningMixFormer(pl.LightningModule):
  """Lightning module training st3_mxformer.MixFormer with an MSE loss on masks.

  Attributes:
    model: the wrapped MixFormer network.
    start_lr: initial AdamW learning rate, read in configure_optimizers.
    lr_gamma: multiplicative learning-rate decay applied by StepLR every epoch.
  """

  def __init__(self, base_model=None):
    super().__init__()
    mixformer_config = st3_mxformer.make_mixformer_config(size_str)
    # Forward base_model to MixFormer only when one was actually supplied.
    optional_base = () if base_model is None else (base_model,)
    self.model = st3_mxformer.MixFormer(mixformer_config, *optional_base)
    self.start_lr = 1e-3
    self.lr_gamma = 0.75

  def forward(self, _s, _t):
    """Predict a mask from a search batch _s and a target batch _t."""
    prediction = self.model(_s, _t)
    return prediction

  def _mse_on_batch(self, batch):
    """MSE between the predicted and the reference mask of one batch."""
    search, target, reference_mask = batch
    return F.mse_loss(self.model(search, target), reference_mask)

  def training_step(self, batch, batch_idx):
    """Compute the training loss and log its epoch-level value as 'train_loss'."""
    batch_loss = self._mse_on_batch(batch)
    self.log('train_loss', batch_loss, on_step=False, on_epoch=True, prog_bar=True)
    return batch_loss

  def validation_step(self, batch, batch_idx):
    """Log the epoch-level validation loss as 'val_loss'."""
    self.log('val_loss', self._mse_on_batch(batch), on_step=False, on_epoch=True, prog_bar=True)

  def configure_optimizers(self):
    """AdamW over the model parameters plus a StepLR decay stepped every epoch."""
    adamw = torch.optim.AdamW(self.model.parameters(), lr=self.start_lr, weight_decay=1e-6)
    step_decay = torch.optim.lr_scheduler.StepLR(adamw, step_size=1, gamma=self.lr_gamma)
    scheduler_config = {'scheduler': step_decay, 'interval': 'epoch', 'frequency': 1}
    return {'optimizer': adamw, 'lr_scheduler': scheduler_config}

  def train_dataloader(self):
    """Shuffled loader over the module-level training tensors."""
    return DataLoader(MyDataset(trn_s, trn_t, trn_m), batch_size=128, shuffle=True, num_workers=2)

  def val_dataloader(self):
    """Sequential loader over the module-level validation tensors."""
    return DataLoader(MyDataset(vld_s, vld_t, vld_m), batch_size=128, shuffle=False, num_workers=2)

In [None]:
def get_trainer(max_epochs):
  """Build a Trainer with val_loss checkpointing and CSV logging.

  Args:
    max_epochs: number of epochs the returned trainer runs at most.

  Returns:
    A pl.Trainer that keeps the 5 checkpoints with the lowest 'val_loss' in
    'my_model/' and writes CSV logs under 'logs'.
  """
  checkpoint_options = {
    'monitor': 'val_loss',
    'dirpath': 'my_model/',
    'filename': 'model-{epoch:02d}-{val_loss:.2f}',
    'save_top_k': 5,
    'mode': 'min',
  }
  return pl.Trainer(
    max_epochs=max_epochs,
    callbacks=[ModelCheckpoint(**checkpoint_options)],
    logger=pl_loggers.CSVLogger('logs'),
  )

In [None]:
# Load the Stage2 base-model checkpoint that matches the selected size_str.
base_model_prefix = '/content/drive/My Drive/Data/DiplomeGenerated/Stage2'
base_model_path = os.path.join(base_model_prefix, f'models/model_{size_str}.ckpt')
base_model = LightningBaseModel.load_from_checkpoint(base_model_path)

In [None]:
# Build the MixFormer module around the network held by the loaded base model.
model = LightningMixFormer(base_model.model)

In [None]:
stages = [1, 2]

# Per-stage schedule: epoch budget, optimizer settings, and whether the base
# network's parameters require gradients.
stage_plan = {
  1: {'max_epochs': 1, 'start_lr': 5e-4, 'lr_gamma': 0.8, 'base_requires_grad': False},
  2: {'max_epochs': 5, 'start_lr': 1e-3, 'lr_gamma': 0.75, 'base_requires_grad': True},
}

for stage in stages:
  plan = stage_plan.get(stage)
  if plan is None:
    # Stage numbers without a plan are skipped.
    continue
  trainer = get_trainer(plan['max_epochs'])
  model.start_lr = plan['start_lr']
  model.lr_gamma = plan['lr_gamma']
  model.model.set_base_requires_grad(plan['base_requires_grad'])
  trainer.fit(model)

In [None]:
# Save the trained model twice: a local copy in the working directory and a
# copy on Drive under Stage3/models, named after size_str.
model_prefix = '/content/drive/My Drive/Data/DiplomeGenerated/Stage3'

trainer.save_checkpoint("model.ckpt")
checkpoint_path = os.path.join(model_prefix, f'models/model_{size_str}.ckpt')
trainer.save_checkpoint(checkpoint_path)

In [None]:
# Reload the checkpoint from Drive, switch to eval mode and move it to the GPU.
checkpoint_path = os.path.join(model_prefix, f'models/model_{size_str}.ckpt')
model = LightningMixFormer.load_from_checkpoint(checkpoint_path=checkpoint_path)
model = model.eval().to('cuda')

In [None]:
def get_masks(searches, targets, outs):
  """Predict masks for every (search, target) pair with the global `model`.

  Args:
    searches: search samples fed to the model.
    targets: target samples fed to the model.
    outs: reference masks; batched alongside the inputs but not used for the
      prediction.

  Returns:
    CPU tensor with all predictions concatenated along dim 0, clamped to [0, 1].
  """
  loader = torch.utils.data.DataLoader(
    MyDataset(searches, targets, outs), batch_size=256, shuffle=False, num_workers=2)
  collected = []
  with torch.no_grad():
    for search_batch, target_batch, _ in loader:
      predicted = model(search_batch.to('cuda'), target_batch.to('cuda'))
      collected.append(torch.clamp(predicted, min=0, max=1).cpu())
  return torch.cat(collected, dim=0)

# Predict masks for the whole validation patch and report the result shape.
valid_model_masks = get_masks(vld_s, vld_t, vld_m)
#valid_model_masks = get_masks(trn_s[0:1000], trn_t[0:1000], trn_m[0:1000])
print(valid_model_masks.shape)

In [None]:
def plot_image2(ss, tt, mm, mm_pred, index):
  """Show one sample as a 2x2 grid: search, target, reference and predicted mask.

  Args:
    ss: indexable batch of search images; `.numpy()` is called on the item.
    tt: indexable batch of target images; `.numpy()` is called on the item.
    mm: indexable batch of reference masks.
    mm_pred: indexable batch of predicted masks.
    index: index of the sample to display.

  Both masks are drawn in grayscale on a fixed 0..1 scale so they can be
  compared directly.
  """
  # No plt.clf() here: when no figure is open it creates an extra empty figure
  # that gets shown next to the real one, and plt.subplots() always starts a
  # fresh figure anyway.
  grayscale = {'cmap': 'gray', 'vmin': 0, 'vmax': 1}
  panels = (
    ((0, 0), 'Search Image', ss[index].numpy(), {}),
    ((0, 1), 'Target Image', tt[index].numpy(), {}),
    ((1, 0), 'Mask', mm[index].numpy(), grayscale),
    ((1, 1), 'Predicted Mask', mm_pred[index].numpy(), grayscale),
  )
  fig, ax = plt.subplots(2, 2, figsize=(6, 6))
  for position, title, image, imshow_kwargs in panels:
    ax[position].imshow(image, **imshow_kwargs)
    ax[position].set_title(title)
  plt.show()

In [None]:
# Compare the reference and predicted mask of one validation sample.
ind = 2
plot_image2(vld_s, vld_t, vld_m, valid_model_masks, ind)
#plot_image2(trn_s, trn_t, trn_m, valid_model_masks, ind)

In [None]:
# Mean squared error between the predicted masks (already clamped to [0, 1]
# by get_masks) and the reference validation masks.
loss = F.mse_loss(valid_model_masks, vld_m)
#loss = F.mse_loss(valid_model_masks, trn_m[0:1000])
print(f"loss = {loss}")

In [None]:
import shutil

# Remove the local training artefacts (CSV logs and checkpoints).
# Set the condition to False to keep them.
if True:
  for artefact_dir in ("/content/logs", "/content/my_model"):
    # ignore_errors=True keeps the cleanup best-effort (a missing directory is
    # fine) without a bare `except:` that would also swallow
    # KeyboardInterrupt and SystemExit.
    shutil.rmtree(artefact_dir, ignore_errors=True)