In [None]:
# Install the third-party packages imported further down.
# Both stdout and stderr are redirected to /dev/null, so a failed install is silent.
# NOTE(review): versions are not pinned — confirm this is intentional.
!pip install pytorch-lightning > /dev/null 2>&1
!pip install einops > /dev/null 2>&1
!pip install timm > /dev/null 2>&1

In [None]:
# Remove any previous checkout, clone the repository again and install it
# in editable mode.
!rm -rf MixformerFromScratch
!git clone https://github.com/reeWorlds/MixformerFromScratch.git
!pip install -e "MixformerFromScratch"

# Re-run the site initialisation.
# NOTE(review): presumably needed so the editable install becomes importable
# without restarting the kernel — confirm.
import site
site.main()

In [None]:
# Disabled by default: changing the condition to True terminates the Python
# process immediately via os._exit(0).
# NOTE(review): presumably a manual way to force a runtime restart — confirm.
if False:
  import os
  os._exit(0)

In [None]:
import torch
import pytorch_lightning as pl
import numpy as np
import os
import gc

import matplotlib.pyplot as plt

import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import math
import seaborn as sns
from functools import reduce
from operator import mul
from einops import rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import DropPath

from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint

from Mixformer import st4_mxformer

In [None]:
# Mount Google Drive under /content/drive so the files below path_prefix are reachable.
from google.colab import drive
drive.mount('/content/drive')
# Root directory of the data files; the commented-out line points at the
# Stage4_Easy directory instead of Stage4_Hard.
#path_prefix = '/content/drive/My Drive/Data/DiplomeGenerated/Stage4_Easy'
path_prefix = '/content/drive/My Drive/Data/DiplomeGenerated/Stage4_Hard'

In [None]:
# Patch indices assigned to the training and validation splits.
train_patches_nums = list(range(40)) # up to 45
valid_patches_nums = list(range(40, 45))

# Placeholders for the (search, target, answer) tensors of each split.
trn_s, trn_t, trn_ans = None, None, None
vld_s, vld_t, vld_ans = None, None, None

def get_tensor_by_path(file_path, size, shape, dtype):
  """Expose a flat binary file as a tensor of the requested shape.

  The file is memory-mapped read-only as `size` elements of `dtype`,
  wrapped with `torch.from_numpy` and reshaped to `shape`.

  Args:
    file_path: path of the binary file.
    size: number of elements to map from the file.
    shape: iterable of dimensions for the returned tensor.
    dtype: NumPy dtype of the stored elements.

  Returns:
    A tensor with dimensions `shape` built on top of the memory map.
  """
  flat = torch.from_numpy(np.memmap(file_path, dtype=dtype, mode='r', shape=(size,)))
  return torch.reshape(flat, tuple(shape))

def get_data_by_num(path_num):
  """Load the search, target and answer tensors stored for one patch.

  Files are looked up under the notebook-level `path_prefix` as
  `patch{path_num}_search.bin`, `patch{path_num}_target.bin` and
  `patch{path_num}_output.bin`, all read as float32.

  Args:
    path_num: index of the patch to load.

  Returns:
    Tuple `(s, t, ans)` shaped (10000, 64, 64, 3), (10000, 48, 48, 3)
    and (10000, 2) respectively.
  """
  # (file suffix, tensor shape) for every file belonging to a patch.
  # The original kept a commented-out alternative answer shape of (10000, 3).
  layout = (
      ('search', (10000, 64, 64, 3)),
      ('target', (10000, 48, 48, 3)),
      ('output', (10000, 2)),
  )
  loaded = []
  for suffix, dims in layout:
    bin_path = os.path.join(path_prefix, f'patch{path_num}_{suffix}.bin')
    loaded.append(get_tensor_by_path(bin_path, reduce(mul, dims), dims, np.float32))
  s, t, ans = loaded
  return s, t, ans

def get_data(nums):
  """Load several patches and concatenate them along the first dimension.

  Args:
    nums: iterable of patch indices, loaded in the given order.

  Returns:
    Tuple `(tensor_s, tensor_t, tensor_ans)` holding the concatenated
    search, target and answer tensors.
  """
  # One bucket per returned tensor, in the order get_data_by_num yields them.
  buckets = ([], [], [])
  for patch_num in nums:
    for bucket, tensor in zip(buckets, get_data_by_num(patch_num)):
      bucket.append(tensor)
    # Progress message for every patch index divisible by five.
    if patch_num % 5 == 0:
      print(f'Finished patch_num = {patch_num}')
  return tuple(torch.cat(bucket, dim=0) for bucket in buckets)

# Load both splits into memory using the patch indices configured above.
trn_s, trn_t, trn_ans = get_data(train_patches_nums)
vld_s, vld_t, vld_ans = get_data(valid_patches_nums)

# Release temporaries created while loading and concatenating the patches.
gc.collect()

# Report the shapes of both splits (the second line used to be mislabelled as "train").
print(f'train data shapes are s:{trn_s.shape} t:{trn_t.shape} ans:{trn_ans.shape}')
print(f'valid data shapes are s:{vld_s.shape} t:{vld_t.shape} ans:{vld_ans.shape}')

In [None]:
def plot_image(img_s, img_t, img_ans, idx):
  """Show one search/target pair with the reference point marked on the search image.

  Args:
    img_s: tensor of search images, indexed by sample.
    img_t: tensor of target images, indexed by sample.
    img_ans: tensor whose first two columns are the (x, y) position of the
      point, multiplied by the search image size to get pixel coordinates.
    idx: index of the sample to display.
  """
  img_s_np = img_s[idx].numpy()
  img_t_np = img_t[idx].numpy()
  cx, cy = img_ans[idx, 0:2].numpy()
  # Scale by the actual image size instead of a hard-coded 64
  # (identical result for the 64x64 search images used in this notebook).
  height, width = img_s_np.shape[:2]
  # No plt.clf() here: with no current figure it would create an extra,
  # empty figure that gets displayed next to the real one.
  fig, ax = plt.subplots(1, 2, figsize=(6, 3))
  ax[0].imshow(img_s_np)
  ax[0].set_title('Search')
  ax[1].imshow(img_t_np)
  ax[1].set_title('Target')
  ax[0].scatter(int(cx * width), int(cy * height), color='red', s=25)
  plt.show()

In [None]:
# Display one training sample together with its reference point.
idx = 14
plot_image(trn_s, trn_t, trn_ans, idx)

In [None]:
class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset that yields (search, target, answer) triples by index."""

    def __init__(self, _s, _t, _ans):
        # The three containers are stored as given and indexed in parallel.
        self._s, self._t, self._ans = _s, _t, _ans

    def __len__(self):
        # The answer container defines the number of samples.
        return len(self._ans)

    def __getitem__(self, idx):
        sample = (self._s[idx], self._t[idx], self._ans[idx])
        return sample

In [None]:
# Model size key: passed to make_mixformer_config and used in checkpoint file names.
size_str = 'medium'

In [None]:
#def _calc_loss(_pred, _ref):
#  _pred_pos, _pred_scale = _pred[:,0:2], _pred[:,2:3]
#  _ref_pos, _ref_scale = _ref[:,0:2], _ref[:,2:3]
#  loss_pos = F.mse_loss(_pred_pos, _ref_pos)
#  loss_scale = F.mse_loss(_pred_scale, _ref_scale)
#  loss = loss_pos + 0.1 * loss_scale
#  return loss

def _calc_loss(_pred, _ref):
  pred_pos = _pred[:,0:2]
  loss = F.mse_loss(pred_pos, _ref)
  return loss

class LightningMixFormer(pl.LightningModule):
  """Lightning module wrapping `st4_mxformer.MixFormer`, trained with `_calc_loss`.

  The model configuration is selected by the notebook-level `size_str`, and
  the data loaders read the notebook-level train/validation tensors.
  """

  def __init__(self):
    super().__init__()
    self.model = st4_mxformer.MixFormer(st4_mxformer.make_mixformer_config(size_str))
    # Optimiser settings are plain attributes, read in configure_optimizers,
    # so they can be overridden on an instance before training starts.
    self.start_lr = 1e-3
    self.lr_gamma = 0.75

  def forward(self, _s, _t):
    """Run the wrapped model on a (search, target) pair."""
    return self.model(_s, _t)

  def _batch_loss(self, batch):
    """Return `_calc_loss` for one (search, target, reference) batch."""
    search, target, reference = batch
    return _calc_loss(self.model(search, target), reference)

  def training_step(self, batch, batch_idx):
    """Compute the training loss and log it per epoch as 'train_loss'."""
    loss = self._batch_loss(batch)
    self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
    return loss

  def validation_step(self, batch, batch_idx):
    """Compute the validation loss and log it per epoch as 'val_loss'."""
    self.log('val_loss', self._batch_loss(batch), on_step=False, on_epoch=True, prog_bar=True)

  def configure_optimizers(self):
    """AdamW plus a StepLR schedule that multiplies the LR by `lr_gamma` every epoch."""
    optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.start_lr, weight_decay=1e-6)
    lr_scheduler = {
        'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.lr_gamma),
        'interval': 'epoch',
        'frequency': 1,
    }
    return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}

  def train_dataloader(self):
    """Shuffled loader over the training tensors."""
    return DataLoader(MyDataset(trn_s, trn_t, trn_ans), batch_size=128, shuffle=True, num_workers=2)

  def val_dataloader(self):
    """Sequential loader over the validation tensors."""
    return DataLoader(MyDataset(vld_s, vld_t, vld_ans), batch_size=128, shuffle=False, num_workers=2)

In [None]:
def get_trainer(max_epochs):
  """Build a Trainer with checkpointing on 'val_loss' and CSV logging.

  Args:
    max_epochs: maximum number of training epochs.

  Returns:
    A `pl.Trainer` that keeps the five checkpoints with the lowest
    'val_loss' under 'my_model/' and writes CSV logs under 'logs'.
  """
  return pl.Trainer(
      max_epochs=max_epochs,
      callbacks=[
          ModelCheckpoint(
              monitor='val_loss',
              dirpath='my_model/',
              filename='model-{epoch:02d}-{val_loss:.2f}',
              save_top_k=5,
              mode='min',
          )
      ],
      logger=pl_loggers.CSVLogger('logs'),
  )

In [None]:
# Start either from a fresh model or from the checkpoint stored in the
# Stage4_Easy directory for the current size_str.
load_pretrained = True
if not load_pretrained:
  model = LightningMixFormer()
else:
  path_pret_pref = '/content/drive/My Drive/Data/DiplomeGenerated/Stage4_Easy'
  path_pret = os.path.join(path_pret_pref, f'models/model_{size_str}.ckpt')
  model = LightningMixFormer.load_from_checkpoint(path_pret)

In [None]:
# Override the optimiser settings on the instance, then train for up to 10 epochs.
model.start_lr, model.lr_gamma = 1e-3, 0.85
trainer = get_trainer(10)
trainer.fit(model)

In [None]:
# Save the trained weights twice: to the working directory first, then
# under path_prefix.
checkpoint_path = os.path.join(path_prefix, f'models/model_{size_str}.ckpt')
for destination in ("model.ckpt", checkpoint_path):
  trainer.save_checkpoint(destination)

In [None]:
# Reload the checkpoint saved under path_prefix, switch to eval mode and move to the GPU.
checkpoint_path = os.path.join(path_prefix, f'models/model_{size_str}.ckpt')
model = (
    LightningMixFormer.load_from_checkpoint(checkpoint_path=checkpoint_path)
    .eval()
    .to('cuda')
)

In [None]:
def get_outputss(searches, targets, anss):
  """Run the notebook-level `model` over a dataset and collect its clamped outputs.

  Args:
    searches: tensor of search images.
    targets: tensor of target images.
    anss: tensor of reference answers; only used to build the dataset,
      its values are not read here.

  Returns:
    CPU tensor with one row per sample, holding the first two output
    columns of the model clamped to [0, 1].
  """
  dataset = MyDataset(searches, targets, anss)
  data_loader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=2)
  # Move inputs to wherever the model's parameters live rather than to a
  # hard-coded 'cuda', so a CPU-resident model does not cause a device mismatch.
  device = next(model.parameters()).device
  all_anss = []
  with torch.no_grad():
    for _search, _target, _ in data_loader:
      ans = model(_search.to(device), _target.to(device))
      ans = torch.clamp(ans[:, 0:2], min=0, max=1)
      # No detach() needed: nothing is tracked under no_grad.
      all_anss.append(ans.to('cpu'))
  return torch.cat(all_anss, dim=0)

# Predict on the validation split and print the shape of the collected outputs.
valid_anss = get_outputss(vld_s, vld_t, vld_ans)
print(valid_anss.shape)

In [None]:
# First ten rows side by side: predicted columns first, reference columns after.
print(torch.cat([valid_anss[0:10], vld_ans[0:10]], dim=1))

In [None]:
# Loss of the collected predictions against the validation references.
loss = _calc_loss(valid_anss, vld_ans)
print(loss)
# Baseline for comparison: the loss obtained when every coordinate is predicted as 0.5.
#print(f"ref loss = {_calc_loss(torch.zeros(vld_ans.shape[0], 3) + 0.5, vld_ans)}")
print(f"ref loss = {_calc_loss(torch.zeros(vld_ans.shape[0], 2) + 0.5, vld_ans)}")

In [None]:
def plot_image2(img_s, img_t, img_ans, img_ans_p, idx):
  """Show one search/target pair with the reference and predicted points marked.

  The reference point is drawn in red and the predicted point in yellow,
  both on the search image.

  Args:
    img_s: tensor of search images, indexed by sample.
    img_t: tensor of target images, indexed by sample.
    img_ans: tensor whose first two columns are the reference (x, y) position,
      multiplied by the search image size to get pixel coordinates.
    img_ans_p: tensor whose first two columns are the predicted (x, y)
      position, scaled the same way.
    idx: index of the sample to display.
  """
  img_s_np = img_s[idx].numpy()
  img_t_np = img_t[idx].numpy()
  # Scale by the actual image size instead of a hard-coded 64
  # (identical result for the 64x64 search images used in this notebook).
  height, width = img_s_np.shape[:2]
  # No plt.clf() here: with no current figure it would create an extra,
  # empty figure that gets displayed next to the real one.
  fig, ax = plt.subplots(1, 2, figsize=(6, 3))
  ax[0].imshow(img_s_np)
  ax[0].set_title('Search')
  ax[1].imshow(img_t_np)
  ax[1].set_title('Target')
  # Reference first, prediction second with a smaller marker drawn on top.
  for points, color, size in ((img_ans, 'red', 15), (img_ans_p, 'yellow', 10)):
    cx, cy = points[idx, 0:2].numpy()
    ax[0].scatter(int(cx * width), int(cy * height), color=color, s=size)
  plt.show()

In [None]:
# Display one validation sample with both the reference and the predicted point.
idx = 15
plot_image2(vld_s, vld_t, vld_ans, valid_anss, idx)

In [None]:
def count_exceeding_distances(vld_ans, valid_anss, threshold):
  """Count rows whose Euclidean distance between the two tensors exceeds `threshold`.

  Args:
    vld_ans: 2-D tensor of reference rows.
    valid_anss: 2-D tensor of predicted rows, same shape as `vld_ans`.
    threshold: distance above which a row is counted (strict comparison).

  Returns:
    Zero-dimensional integer tensor with the number of such rows.
  """
  offsets = vld_ans - valid_anss
  too_far = torch.norm(offsets, dim=1, p=2) > threshold
  return too_far.sum()
# Share of validation samples whose prediction is at most 0.1 away from the reference.
cnt_bad = count_exceeding_distances(vld_ans, valid_anss, 0.1)
cnt_good = vld_ans.shape[0] - cnt_bad
print(f"% of good is {cnt_good / vld_ans.shape[0] * 100:.2f}%")

In [None]:
import shutil

# Best-effort removal of the local log and checkpoint directories.
# ignore_errors=True keeps the "missing directory is fine" behaviour without a
# bare `except:` that would also swallow KeyboardInterrupt and similar.
if True:
  for stale_dir in ("/content/logs", "/content/my_model"):
    shutil.rmtree(stale_dir, ignore_errors=True)