In [None]:
!pip install pytorch-lightning > /dev/null 2>&1
!pip install einops > /dev/null 2>&1
!pip install timm > /dev/null 2>&1

In [None]:
# Remove any previous checkout so the clone below always starts from a clean state.
!rm -rf MixformerFromScratch
!git clone https://github.com/reeWorlds/MixformerFromScratch.git
# Editable install of the freshly cloned repository.
!pip install -e "MixformerFromScratch"

# Re-run site initialisation; presumably this lets the running kernel pick up the
# editable install without a restart (TODO confirm).
import site
site.main()

In [None]:
# Manual switch: change the condition to True to terminate the kernel process
# immediately via os._exit(0). Disabled by default, so this cell is a no-op.
if False:
  import os
  os._exit(0)

In [None]:
import torch
import pytorch_lightning as pl
import numpy as np
import os
import gc

import matplotlib.pyplot as plt

import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import DropPath

from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint

from Mixformer import st2_ae

In [None]:
# Mount Google Drive at /content/drive so files stored there are reachable.
from google.colab import drive
drive.mount('/content/drive')
# Root folder on Drive; the notebook reads the patch files from it and writes
# model checkpoints below it.
data_prefix = '/content/drive/My Drive/Data/DiplomeGenerated/Stage2_AE'

In [None]:
# Directory that holds the patch files read by get_data_by_num.
data_folder_path = data_prefix

# Indices of the patch files to load (up to 21).
train_patches_nums = [*range(21)]

# Placeholder pair; rebound to the concatenated tensor once the patches are loaded.
train_data = (None, None)

def get_tensor_by_path(file_path, size, shape, dtype):
  """Memory-map a flat binary file and expose it as a tensor of the given shape.

  Args:
    file_path: path of the raw binary file.
    size: number of `dtype` elements to map from the file.
    shape: target shape; its element count has to equal `size`.
    dtype: NumPy dtype of the stored elements.

  Returns:
    A torch tensor of shape `shape` created from the read-only memory map.
  """
  flat_view = np.memmap(file_path, dtype=dtype, mode='r', shape=(size,))
  return torch.from_numpy(flat_view).reshape(*shape)

def get_data_by_num(path_num):
  """Load one patch file as a float32 tensor of shape (10000, 64, 64, 3).

  Args:
    path_num: index of the patch; selects `patch{path_num}_64x64.bin` inside
      `data_folder_path`.

  Returns:
    The tensor produced by `get_tensor_by_path` for that file.
  """
  patch_shape = (10000, 64, 64, 3)
  num_values = 10000 * 64 * 64 * 3
  patch_file = os.path.join(data_folder_path, f'patch{path_num}_64x64.bin')
  return get_tensor_by_path(patch_file, num_values, patch_shape, np.float32)
# Map every requested patch, reporting progress on every fourth index.
list_data = []
for patch_num in train_patches_nums:
  list_data.append(get_data_by_num(patch_num))
  if not patch_num % 4:
    print(f'Finished patch_num = {patch_num}')

# Join all patches along the first (sample) axis into one tensor.
train_data = torch.cat(list_data, dim=0)

gc.collect()

print(f'train data shapes are d:{train_data.shape}')

In [None]:
def plot_image(data, index):
  """Show a single image from `data` in its own small figure.

  Args:
    data: tensor indexed along its first axis; `data[index]` is converted with
      `.numpy()` and handed to `imshow` (presumably an H x W x 3 float image —
      TODO confirm).
    index: position of the image inside `data`.
  """
  # No `plt.clf()` here: with no figure open it creates an empty one, which is
  # then rendered as a stray blank output next to the real plot.
  img_data_np = data[index].numpy()
  fig, ax = plt.subplots(1, 1, figsize=(3, 3))
  ax.imshow(img_data_np)
  ax.set_title('Image')
  plt.show()

In [None]:
# Sanity check: display the first loaded training image.
ind = 0
plot_image(train_data, ind)

In [None]:
class MyDataset(torch.utils.data.Dataset):
    """Minimal map-style dataset that serves items straight from an indexable container."""

    def __init__(self, _data):
        """Keep a reference to `_data`; nothing is copied or transformed."""
        self._data = _data

    def __len__(self):
        """Number of items, as reported by `len` on the wrapped container."""
        n_items = len(self._data)
        return n_items

    def __getitem__(self, idx):
        """Return the item stored at position `idx`, unchanged."""
        item = self._data[idx]
        return item

In [None]:
class LightningMixFormer(pl.LightningModule):
  """Lightning wrapper that trains `st2_ae.Autoencoder` to reproduce its input.

  `start_lr` and `lr_gamma` are plain attributes read in `configure_optimizers`,
  so they can be changed on the instance before fitting.
  """

  def __init__(self):
    super().__init__()
    self.model = st2_ae.Autoencoder(st2_ae.ConfigGeneration.make_ae_config())
    # Initial learning rate and the multiplicative decay applied each epoch.
    self.start_lr = 1e-3
    self.lr_gamma = 0.86

  def forward(self, _data):
    """Run the wrapped autoencoder on `_data` and return its output."""
    return self.model(_data)

  def get_loss(self, _data, _data_pred):
    """Mean-squared error between the prediction and the original input."""
    return F.mse_loss(_data_pred, _data)

  def training_step(self, batch, batch_idx):
    """Reconstruct `batch`, log the epoch-level `train_loss`, and return the loss."""
    reconstruction = self.model(batch)
    loss = self.get_loss(batch, reconstruction)
    self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
    return loss

  def configure_optimizers(self):
    """AdamW over the autoencoder parameters plus a StepLR stepped once per epoch."""
    optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.start_lr, weight_decay=1e-6)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.lr_gamma)
    lr_scheduler_config = {'scheduler': scheduler, 'interval': 'epoch', 'frequency': 1}
    return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler_config}

  def train_dataloader(self):
    """Shuffled loader over the module-level `train_data` (batch size 128, 2 workers)."""
    return torch.utils.data.DataLoader(MyDataset(train_data), batch_size=128,
                                       shuffle=True, num_workers=2)

In [None]:
def get_trainer(max_epochs, monitor='train_loss'):
  """Build a Trainer with CSV logging and top-5 checkpointing on `monitor`.

  Args:
    max_epochs: number of epochs to train for.
    monitor: name of the logged metric used to rank checkpoints (lower is
      better). Defaults to 'train_loss', the metric logged by the training
      step. The previously hard-coded 'val_loss' is never logged because no
      validation loop exists, so no checkpoint could be ranked or saved.

  Returns:
    The configured `pl.Trainer`.
  """
  # The metric name is spliced into the file-name template; four decimals keep
  # small loss values distinguishable in the checkpoint names.
  checkpoint_callback = ModelCheckpoint(monitor=monitor, dirpath='my_model/',
                                        filename='model-{epoch:02d}-{' + monitor + ':.4f}',
                                        save_top_k=5, mode='min')
  csv_logger = pl_loggers.CSVLogger('logs')
  trainer = pl.Trainer(max_epochs=max_epochs, callbacks=[checkpoint_callback],
                       logger=csv_logger)
  return trainer

In [None]:
# Fresh, untrained Lightning module.
model = LightningMixFormer()

In [None]:
# Train for 12 epochs. The learning-rate settings are overridden on the instance
# before fitting; configure_optimizers reads these attributes.
trainer = get_trainer(12)
model.start_lr = 1e-3
model.lr_gamma = 0.75
trainer.fit(model)

In [None]:
import shutil

# Version tag embedded in the checkpoint file name stored under data_prefix.
model_v = 1

# Write the checkpoint twice: first into the working directory, then to the
# versioned path under data_prefix.
model_checkpoint_path = os.path.join(data_prefix, f'models/model_v{model_v}.ckpt')
for target_path in ("model.ckpt", model_checkpoint_path):
  trainer.save_checkpoint(target_path)

In [None]:
# Reload the versioned checkpoint and prepare the module for GPU inference.
checkpoint_path = os.path.join(data_prefix, f'models/model_v{model_v}.ckpt')
model = LightningMixFormer.load_from_checkpoint(checkpoint_path=checkpoint_path)
# eval() and to() both return the module; binding the result keeps the cell
# from echoing the module repr.
model = model.eval().to('cuda')

In [None]:
def get_outputs(datas, net=None):
  """Run a module over `datas` in batches and return the outputs on the CPU.

  Args:
    datas: indexable collection of input samples (wrapped in `MyDataset`).
    net: module to evaluate; defaults to the module-level `model`.

  Returns:
    One tensor holding the per-batch outputs concatenated along dim 0, in the
    same order as `datas`.
  """
  net = model if net is None else net
  # Follow the module's own device instead of assuming 'cuda'.
  device = next(net.parameters()).device
  dataset = MyDataset(datas)
  data_loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=False, num_workers=2)
  all_out_data = []
  # Force eval mode for the duration of the pass so train-only layers do not
  # perturb the outputs, then restore whatever mode the module was in.
  was_training = net.training
  net.eval()
  try:
    with torch.no_grad():
      for batch in data_loader:
        out_d = net(batch.to(device))
        all_out_data.append(out_d.cpu())
  finally:
    net.train(was_training)
  return torch.cat(all_out_data, dim=0)

# Run the model on the first 1000 training samples for visual inspection.
train_data_outs = get_outputs(train_data[0:1000])

In [None]:
def plot_image2(d, d_out, index):
  """Show an input image next to the corresponding model output.

  Args:
    d: tensor of original images, indexed along its first axis.
    d_out: tensor of model outputs aligned with `d`; values are clamped to
      [0, 1] before display.
    index: position of the image pair to show.
  """
  # No `plt.clf()` here: with no figure open it creates an empty one, which is
  # then rendered as a stray blank output next to the real plot.
  img_d_np = d[index].numpy()
  img_d_out_np = d_out[index].clamp(0, 1).numpy()
  fig, ax = plt.subplots(1, 2, figsize=(6, 6))
  ax[0].imshow(img_d_np)
  ax[0].set_title('Image')
  ax[1].imshow(img_d_out_np)
  ax[1].set_title('Image_out')
  plt.show()

In [None]:
# Compare training sample 5 with its model output.
ind = 5
plot_image2(train_data, train_data_outs, ind)

In [None]:
import shutil

# Best-effort cleanup of the two local output directories. `ignore_errors=True`
# skips directories that are missing or cannot be removed, without also
# swallowing unrelated exceptions such as KeyboardInterrupt (which the previous
# bare `except:` did). Set the condition to False to keep the directories.
if True:
  for stale_dir in ("/content/logs", "/content/my_model"):
    shutil.rmtree(stale_dir, ignore_errors=True)