In [None]:
!pip install pytorch-lightning > /dev/null 2>&1
!pip install einops > /dev/null 2>&1
!pip install timm > /dev/null 2>&1

In [None]:
# Fetch the project repository (presumably the source of the `Mixformer` package
# imported in the next cells) and install it in editable mode. The `rm -rf`
# makes the cell safe to re-run.
!rm -rf MixformerFromScratch
!git clone https://github.com/reeWorlds/MixformerFromScratch.git
!pip install -e "MixformerFromScratch"

# Re-run site initialisation so path entries added by the editable install are
# picked up by this already-running kernel (presumably to avoid a restart).
import site
site.main()

In [None]:
# Manual hard-exit switch: change False to True to terminate the Python process
# immediately. NOTE(review): presumably used to force a kernel restart after the
# installs above; relies on the notebook host restarting the kernel — confirm.
if False:
  import os
  os._exit(0)

In [None]:
import torch
import pytorch_lightning as pl
import numpy as np
import os
import gc

import matplotlib.pyplot as plt

import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import math
import seaborn as sns
from functools import reduce
from operator import mul
from einops import rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import DropPath

from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint

from Mixformer import st2_xformer

In [None]:
# Mount Google Drive (Colab only); the Stage2 generated dataset lives there.
from google.colab import drive
drive.mount('/content/drive')
data_prefix = '/content/drive/My Drive/Data/DiplomeGenerated/Stage2'

In [None]:
# Directory containing the patch{N}_* files read by get_data_by_num.
data_folder_path = data_prefix

# Patches 0..19 are concatenated for training; patch 20 is held out for validation.
train_patches_nums = list(range(20)) # patch files 0..19
valid_pathch_num = 20  # (sic) validation patch index

# Placeholders only; all of these are reassigned by the loading code below.
trn_s, trn_t, trn_s_stats, trn_t_stats, trn_types, trn_infos = None, None, None, None, None, None
vld_s, vld_t, vld_s_stats, vld_t_stats, vld_types, vld_infos = None, None, None, None, None, None

def get_tensor_by_path(file_path, size, shape, dtype):
  """Memory-map a raw binary file and view it as a tensor of ``shape``.

  Args:
    file_path: path to a headerless binary file of ``dtype`` elements.
    size: number of elements to map; must equal the product of ``shape``.
    shape: target tensor shape.
    dtype: numpy dtype of the file contents.

  Returns:
    A torch tensor sharing memory with the memory map (no copy is made here).

  Raises:
    ValueError: if ``size`` does not match ``shape``.
  """
  expected = math.prod(shape)
  if size != expected:
    raise ValueError(
        f'size={size} does not match shape {tuple(shape)} ({expected} elements)')
  # mode='c' (copy-on-write) leaves the file on disk untouched but yields a
  # writable array, so torch.from_numpy does not emit its "given NumPy array is
  # not writable" UserWarning. The file is still memory-mapped, not read eagerly.
  mmapped_array = np.memmap(file_path, dtype=dtype, mode='c', shape=(size,))
  return torch.from_numpy(mmapped_array).reshape(*shape)

def get_data_by_num(path_num, n_samples=10000, num_classes=20):
  """Load one generated data patch from ``data_folder_path``.

  Args:
    path_num: patch index used in the ``patch{N}_*`` file names.
    n_samples: number of samples stored in every file of the patch.
    num_classes: number of type classes used for the one-hot encoding.

  Returns:
    (s, t, s_stats, t_stats, types, infos):
      s:       (n_samples, 64, 64, 3) float32 images
      t:       (n_samples, 48, 48, 3) float32 images
      s_stats: (n_samples, 8, 8, 5) float32 stats maps for ``s``
      t_stats: (n_samples, 6, 6, 5) float32 stats maps for ``t``
      types:   (n_samples, num_classes) float one-hot labels
      infos:   (n_samples, 3) float32 per-sample info values
  """
  def load_bin(suffix, *item_shape):
    # Every .bin file is a raw float32 array of n_samples items.
    file_path = os.path.join(data_folder_path, f'patch{path_num}_{suffix}.bin')
    shape = (n_samples, *item_shape)
    return get_tensor_by_path(file_path, reduce(mul, shape), shape, np.float32)

  s = load_bin('64x64', 64, 64, 3)
  t = load_bin('48x48', 48, 48, 3)
  s_stats = load_bin('64x64_stats', 8, 8, 5)
  t_stats = load_bin('48x48_stats', 6, 6, 5)
  labels = torch.load(os.path.join(data_folder_path, f'patch{path_num}_labels.pt'))
  # F.one_hot requires an int64 index tensor; .long() is a no-op if it already is.
  types = F.one_hot(labels.long(), num_classes).float()
  infos = load_bin('info', 3)
  return s, t, s_stats, t_stats, types, infos

# Load every training patch (memory-mapped) and concatenate into single tensors.
list_s, list_t, list_s_stats, list_t_stats, list_types, list_infos = [], [], [], [], [], []

for patch_num in train_patches_nums:
  s, t, s_stats, t_stats, types, infos = get_data_by_num(patch_num)
  list_s.append(s)
  list_t.append(t)
  list_s_stats.append(s_stats)
  list_t_stats.append(t_stats)
  list_types.append(types)
  list_infos.append(infos)
  if patch_num % 4 == 0:
    print(f'Finished patch_num = {patch_num}')

trn_s = torch.cat(list_s, dim=0)
trn_t = torch.cat(list_t, dim=0)
trn_s_stats = torch.cat(list_s_stats, dim=0)
trn_t_stats = torch.cat(list_t_stats, dim=0)
trn_types = torch.cat(list_types, dim=0)
trn_infos = torch.cat(list_infos, dim=0)

vld_s, vld_t, vld_s_stats, vld_t_stats, vld_types, vld_infos = get_data_by_num(valid_pathch_num)

# torch.cat copied the data, so the per-patch views are no longer needed.
del list_s, list_t, list_s_stats, list_t_stats, list_types, list_infos
gc.collect()

print(f'train data shapes are s:{trn_s.shape} t:{trn_t.shape} s_s:{trn_s_stats.shape} \n t_s:{trn_t_stats.shape} tp:{trn_types.shape} inf:{trn_infos.shape}')
print(f'valid data shapes are s:{vld_s.shape} t:{vld_t.shape} s_s:{vld_s_stats.shape} \n t_s:{vld_t_stats.shape} tp:{vld_types.shape} inf:{vld_infos.shape}')

In [None]:
map_names = ['Water', 'Sand', 'Grass', 'Hill', 'Mountain']
def plot_image(img, msk, idx):
  """Show sample ``idx``: the image followed by one panel per entry of ``map_names``.

  Args:
    img: image tensor indexed as img[idx] -> (H, W, 3).
    msk: 4-D tensor whose last axis has one channel per entry of ``map_names``
      (e.g. the *_stats tensors).
    idx: sample index.
  """
  n_maps = len(map_names)
  # No plt.clf(): it cleared (or created) an unrelated "current" figure that the
  # inline backend then rendered as an empty "<Figure ... with 0 Axes>".
  fig, axes = plt.subplots(1, 1 + n_maps, figsize=(10, 2), squeeze=False)
  axes[0, 0].imshow(img[idx].numpy())
  axes[0, 0].set_title('Image')
  for i, name in enumerate(map_names):
    axes[0, 1 + i].imshow(msk[idx, :, :, i].numpy())
    axes[0, 1 + i].set_title(name)
  plt.show()

In [None]:
# Sanity check: show one training sample's type label, info vector, image and maps.
idx = 5
print(f'type = {torch.argmax(trn_types[idx])}')
print(f'info = {trn_infos[idx]}')
plot_image(trn_s, trn_s_stats, idx)

In [None]:
class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset over six parallel, equally indexed tensors.

    Item ``idx`` is ``((s, t), (s_stats, t_stats, types, infos))``, where each
    element is the ``idx``-th entry of the matching constructor argument. In this
    notebook ``s``/``t`` are the 64x64 and 48x48 image tensors.
    """

    def __init__(self, _s, _t, _s_stats, _t_stats, _types, _infos):
        # Model inputs.
        self._s, self._t = _s, _t
        # Targets: per-input stats maps plus the shared type / info labels.
        self._s_stats, self._t_stats = _s_stats, _t_stats
        self._types, self._infos = _types, _infos

    def __len__(self):
        # The first input defines the dataset length.
        return len(self._s)

    def __getitem__(self, idx):
        inputs = (self._s[idx], self._t[idx])
        targets = (
            self._s_stats[idx],
            self._t_stats[idx],
            self._types[idx],
            self._infos[idx],
        )
        return inputs, targets

In [None]:
def get_stats_loss(pred, ref):
  """Mean-squared error between predicted and reference stats maps (mean reduction)."""
  stats_mse = F.mse_loss(pred, ref)
  return stats_mse

def get_types_loss(pred, ref):
  """Binary cross-entropy between predicted type scores and one-hot targets.

  ``pred`` must already be probabilities in [0, 1]: F.binary_cross_entropy does
  not apply a sigmoid itself.
  """
  return F.binary_cross_entropy(input=pred, target=ref)

def get_infos_loss(pred, ref):
  """Equal-weight average of the per-column MSE over info columns 0, 1 and 2.

  The original variable names suggest the columns are scale, persistence and
  lacunarity. Any columns beyond the third are ignored.
  """
  column_losses = [F.mse_loss(pred[:, col], ref[:, col]) for col in range(3)]
  return sum(column_losses) / 3.0

def get_full_loss(pred, ref):
  """Weighted multi-task loss combining stats, type and info terms.

  Args:
    pred: (stats, types, infos) produced by the model.
    ref:  (stats, types, infos) targets, in the same order.

  Returns:
    2.0 * stats MSE + 0.5 * types BCE + 0.1 * infos MSE.
  """
  pred_stats, pred_types, pred_infos = pred
  ref_stats, ref_types, ref_infos = ref
  stats_term = get_stats_loss(pred_stats, ref_stats)
  types_term = get_types_loss(pred_types, ref_types)
  infos_term = get_infos_loss(pred_infos, ref_infos)
  return 2.0 * stats_term + 0.5 * types_term + 0.1 * infos_term

class LightningTransfromer(pl.LightningModule):
  """Lightning wrapper around ``st2_xformer.Transformer``.

  Each batch holds two image inputs (``s`` and ``t``). The wrapped model is run
  on each separately, scored with ``get_full_loss``, and the two losses are summed.

  Args:
    model_size: preset name passed to ``ConfigGeneration.make_transformer_config``.
    start_lr: initial AdamW learning rate.
    lr_gamma: multiplicative learning-rate decay applied once per epoch (StepLR).
  """
  def __init__(self, model_size='large', start_lr=1e-3, lr_gamma=0.75):
    super().__init__()
    # Store the constructor args in checkpoints so load_from_checkpoint rebuilds
    # the same architecture (older checkpoints without hparams fall back to defaults).
    self.save_hyperparameters()
    config = st2_xformer.ConfigGeneration.make_transformer_config(model_size)
    self.model = st2_xformer.Transformer(config)
    # Plain attributes so callers can still override them after construction;
    # configure_optimizers reads them when fit() starts.
    self.start_lr = start_lr
    self.lr_gamma = lr_gamma

  def forward(self, _d):
    return self.model(_d)

  def _shared_step(self, batch):
    """Return loss(s) + loss(t) for a batch shaped ((s, t), (s_stats, t_stats, types, infos))."""
    data, outs = batch
    s_stats, t_stats, types, infos = outs
    s_preds = self.model(data[0])
    t_preds = self.model(data[1])
    # Both inputs share the same types / infos targets.
    loss_s = get_full_loss(s_preds, (s_stats, types, infos))
    loss_t = get_full_loss(t_preds, (t_stats, types, infos))
    return loss_s + loss_t

  def training_step(self, batch, batch_idx):
    loss = self._shared_step(batch)
    self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
    return loss

  def validation_step(self, batch, batch_idx):
    loss = self._shared_step(batch)
    self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)

  def configure_optimizers(self):
    """AdamW with a per-epoch StepLR decay by ``lr_gamma``."""
    optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.start_lr, weight_decay=1e-6)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.lr_gamma)
    return {'optimizer': optimizer,
            'lr_scheduler': {'scheduler': scheduler, 'interval': 'epoch', 'frequency': 1} }

  def train_dataloader(self):
    # Reads the notebook-global trn_* tensors built by the data-loading cell.
    train_dataset = MyDataset(trn_s, trn_t, trn_s_stats, trn_t_stats, trn_types, trn_infos)
    return DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)

  def val_dataloader(self):
    # Validation order must be deterministic; Lightning warns when a val loader shuffles.
    valid_dataset = MyDataset(vld_s, vld_t, vld_s_stats, vld_t_stats, vld_types, vld_infos)
    return DataLoader(valid_dataset, batch_size=128, shuffle=False, num_workers=2)

In [None]:
def get_trainer(max_epochs):
  """Build a Trainer that logs to CSV and keeps the best checkpoints.

  The 5 checkpoints with the lowest ``val_loss`` are written to ``my_model/``;
  metrics are written by a CSVLogger under ``logs/``.
  """
  best_ckpts = ModelCheckpoint(
      monitor='val_loss',
      mode='min',
      save_top_k=5,
      dirpath='my_model/',
      filename='model-{epoch:02d}-{val_loss:.2f}',
  )
  return pl.Trainer(
      max_epochs=max_epochs,
      callbacks=[best_ckpts],
      logger=pl_loggers.CSVLogger('logs'),
  )

In [None]:
# Build a fresh Lightning module wrapping the transformer.
model = LightningTransfromer()

In [None]:
# Train for 10 epochs. configure_optimizers reads start_lr / lr_gamma when fit()
# starts; the values set here repeat the constructor defaults to make the run's
# LR schedule explicit.
trainer = get_trainer(10)
model.start_lr = 1e-3
model.lr_gamma = 0.75
trainer.fit(model)

In [None]:
import shutil

model_prefix = '/content/drive/My Drive/Data/DiplomeGenerated/Stage2'

# Save the trained model twice: a local copy in the working directory and a
# persistent copy on the mounted Drive under <model_prefix>/models/.
trainer.save_checkpoint("model.ckpt")
checkpoint_path = os.path.join(model_prefix, 'models', 'model_large.ckpt')
trainer.save_checkpoint(checkpoint_path)

In [None]:
# Reload the saved checkpoint from Drive, switch to eval mode and move it to the GPU.
checkpoint_path = os.path.join(model_prefix, 'models', 'model_large.ckpt')
model = LightningTransfromer.load_from_checkpoint(checkpoint_path=checkpoint_path).eval().to('cuda')

In [None]:
@torch.no_grad()
def get_outs(images, net=None, device='cuda', batch_size=128, num_workers=2):
  """Run a model over ``images`` in order and collect its outputs on the CPU.

  Args:
    images: indexable collection (e.g. a tensor) of model inputs. Batches are
      drawn without shuffling, so output row i corresponds to input i.
    net: callable returning a tuple of tensors per batch. Defaults to the
      notebook-global ``model`` so existing calls keep working.
    device: device each batch is moved to before the forward pass.
    batch_size: DataLoader batch size.
    num_workers: DataLoader worker count.

  Returns:
    Tuple with one tensor per model output (here: stats, types, infos), each
    the per-batch outputs concatenated along dim 0.

  Raises:
    ValueError: if ``images`` yields no batches.
  """
  if net is None:
    net = model  # fallback to the global set by the checkpoint-loading cell
  loader = DataLoader(images, batch_size=batch_size, shuffle=False, num_workers=num_workers)
  # Grad is disabled by the decorator, so outputs need no .detach().
  per_batch = [[o.to('cpu') for o in net(batch.to(device))] for batch in loader]
  if not per_batch:
    raise ValueError('get_outs received no images')
  return tuple(torch.cat(parts, dim=0) for parts in zip(*per_batch))

In [None]:
# Model predictions for the 64x64 validation images only (vld_t is not evaluated here).
vld_stats_o, vld_types_o, vld_infos_o = get_outs(vld_s)

In [None]:
def plot_image2(img, msk, msk_ref, idx):
  """Show sample ``idx`` with two sets of maps for side-by-side comparison.

  Layout: row 0 holds only the image (centred), row 1 shows ``msk`` and row 2
  shows ``msk_ref``, one column per entry of ``map_names``. In this notebook it
  is called with ground-truth stats as ``msk`` and model outputs as ``msk_ref``.
  """
  n_maps = len(map_names)
  center = n_maps // 2
  # No plt.clf(): it cleared (or created) an unrelated figure that the inline
  # backend rendered as an empty "<Figure ... with 0 Axes>".
  fig, axes = plt.subplots(3, n_maps, figsize=(10, 6), squeeze=False)
  for col in range(n_maps):
    if col != center:
      axes[0, col].axis('off')  # row 0 only needs the image panel
  axes[0, center].imshow(img[idx].numpy())
  axes[0, center].set_title('Image')
  for i, name in enumerate(map_names):
    axes[1, i].imshow(msk[idx, :, :, i].numpy())
    axes[1, i].set_title(name)
    axes[2, i].imshow(msk_ref[idx, :, :, i].numpy())
    axes[2, i].set_title(name)
  plt.show()

In [None]:
# Visual check of one validation sample: row 1 = reference stats maps,
# row 2 = predicted stats maps.
idx = 0
print(f'type = {torch.argmax(vld_types[idx])}')
print(f'info = {vld_infos[idx]}')
plot_image2(vld_s, vld_s_stats, vld_stats_o, idx)

In [None]:
# Stats MSE on the 64x64 validation inputs (vld_stats_o was computed from vld_s).
stats_loss = get_stats_loss(vld_stats_o, vld_s_stats).item()
print(f"stats_loss = {stats_loss}")

In [None]:
def get_types_accuracy(pred_types, ref_types):
  """Count samples whose predicted type (argmax) equals the reference type.

  Args:
    pred_types: (N, C) predicted type scores.
    ref_types: (N, C) one-hot reference types.

  Returns:
    (cnt_good, cnt_all) as plain Python ints, so they print and format cleanly
    instead of as ``tensor(...)``.
  """
  pred_idx = torch.argmax(pred_types, dim=-1)
  ref_idx = torch.argmax(ref_types, dim=-1)
  cnt_good = int(torch.eq(pred_idx, ref_idx).sum().item())
  return cnt_good, len(pred_types)
# Type-classification accuracy (argmax match) and BCE loss on the validation split.
cnt_good, cnt_all = get_types_accuracy(vld_types_o, vld_types)
print(f"good = {cnt_good} out of all = {cnt_all}")
print(f"% good = {cnt_good/cnt_all * 100}")
types_loss = get_types_loss(vld_types_o, vld_types).item()
print(f"types_loss = {types_loss}")

In [None]:
def get_class_pred_accuracy(pred_types, ref_types, num_classes=20):
  """Confusion-matrix counts as a list of lists: ``mat[real][predicted]``.

  Args:
    pred_types: (N, num_classes) predicted scores; the class is the argmax.
    ref_types: (N, num_classes) one-hot reference types.
    num_classes: size of the square matrix (20 matches the notebook's one-hot encoding).

  Returns:
    num_classes x num_classes nested list of Python ints.
  """
  pred_idx = torch.argmax(pred_types, dim=-1).reshape(-1)
  ref_idx = torch.argmax(ref_types, dim=-1).reshape(-1)
  # Encode each (real, predicted) pair as one flat index and count all pairs in
  # one vectorised pass instead of a Python loop with per-element .item() calls.
  flat = ref_idx * num_classes + pred_idx
  counts = torch.bincount(flat, minlength=num_classes * num_classes)
  return counts.reshape(num_classes, num_classes).tolist()

mat = get_class_pred_accuracy(vld_types_o, vld_types)
mat = np.array(mat)
row_sums = mat.sum(axis=1, keepdims=True)
# Row-normalise so each row shows how samples of one real class were predicted.
# A class with no validation samples has row_sum 0: its row stays NaN (drawn
# blank by the heatmap, as before) without numpy's divide-by-zero RuntimeWarning.
mat = np.divide(mat, row_sums, out=np.full(mat.shape, np.nan), where=row_sums > 0)

plt.figure(figsize=(8, 8))
ax = sns.heatmap(mat, annot=True, fmt=".2f", cmap='viridis', annot_kws={"size": 7})
plt.title('Heatmap of Real vs Predicted Classes')
plt.xlabel('Predicted Classes')
plt.ylabel('Real Classes')
plt.show()

In [None]:
test_n = 5
# First test_n validation samples: reference infos (left 3 columns) next to the
# predicted infos (right 3 columns). The slices use test_n rather than a literal,
# so the preview size is set in one place.
test_data = torch.cat([vld_infos[:test_n], vld_infos_o[:test_n]], dim=1)
print(test_data)
infos_loss = get_infos_loss(vld_infos_o, vld_infos).item()
print(f"infos_loss = {infos_loss}")

In [None]:
# Weighted full loss for the 64x64 validation inputs only. The logged train/val
# losses also add the 48x48 (t) term, so this value is not directly comparable.
pred = (vld_stats_o, vld_types_o, vld_infos_o)
ref = (vld_s_stats, vld_types, vld_infos)
full_loss = get_full_loss(pred, ref).item()
print(f"full_loss = {full_loss}")

In [None]:
import shutil

# Optional cleanup of local training artifacts (CSV logs and Lightning checkpoints).
if True:
  for artifact_dir in ("/content/logs", "/content/my_model"):
    # Best-effort removal: ignore_errors tolerates a missing directory without
    # a bare `except:` that would also swallow KeyboardInterrupt / SystemExit.
    shutil.rmtree(artifact_dir, ignore_errors=True)