In [None]:
import torch
import torch.cuda.amp
import pandas as pd
import datetime
import numpy as np
import os
import time
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm 
import util.lr_sched as lr_sched
import numpy as np

In [None]:
from util.misc import NativeScalerWithGradNormCount as NativeScaler
import timm.optim.optim_factory as optim_factory

In [None]:
# Synthetic stand-in data: 100 single-channel 224x224 images drawn uniformly
# from [0, 1), plus one random scalar target per image.
x_train = torch.rand((100, 1, 224, 224))
y_train = torch.rand((100,))
print(x_train.size(), y_train.size(), sep="\n")

In [None]:
# Compute the mean and std of the training images
def get_mean_and_std(x_train):
    """Return the mean and standard deviation of ``x_train`` per entry of axis 1.

    The statistics are reduced over axes 0, 2 and 3, so a 4-D input of shape
    (B, C, H, W) yields two results of shape (C,).

    Args:
        x_train: 4-D array-like whose ``mean``/``std`` methods accept ``axis``.

    Returns:
        Tuple ``(mean, std)``.
    """
    reduce_axes = (0, 2, 3)
    channel_mean = x_train.mean(axis=reduce_axes)
    channel_std = x_train.std(axis=reduce_axes)
    return channel_mean, channel_std
mean, std = get_mean_and_std(x_train)

In [None]:
mean

In [None]:
std

In [None]:
class make_dataset(torch.utils.data.Dataset):
    """Dataset that serves each image tiled three times along axis 0.

    ``mean`` and ``std`` are stored on the instance but are not applied to the
    data anywhere in this class (no standardization is performed).

    Args:
        x_train: Indexable collection of image tensors, e.g. (N, 1, H, W).
        y_train: Indexable collection of targets aligned with ``x_train``.
        mean: Value kept as ``self.mean`` (default 0.).
        std: Value kept as ``self.std`` (default 1.).
    """

    def __init__(self, x_train, y_train, mean = 0., std = 1.):
        self.x_train, self.y_train = x_train, y_train
        self.mean, self.std = mean, std

    def __len__(self):
        # Number of samples is the size of the leading axis.
        return self.x_train.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, target)`` for ``idx`` with the image as float32."""
        sample = self.x_train[idx]
        # Concatenate three copies along axis 0, e.g. (1, 224, 224) -> (3, 224, 224).
        tiled = torch.cat([sample] * 3, dim=0)
        return tiled.to(dtype=torch.float), self.y_train[idx]

train_data = make_dataset(x_train, y_train, mean = mean, std = std)

In [None]:
# Preview one sample: take the image of item 19, print its shape, then move
# axis 0 to the end ((C, H, W) -> (H, W, C)) before plotting it.
x = train_data[19][0]
print(x.shape)
x = x.permute(1, 2, 0)
plt.imshow(x.detach().numpy())

In [None]:
# GPU indices used by this notebook; the first index selects the CUDA device,
# and the CPU is used when CUDA is unavailable.
gpu_ids = [0, 1, 2, 3]
if torch.cuda.is_available():
    device = torch.device(f"cuda:{gpu_ids[0]}")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Number of samples per DataLoader batch.
batch_size = 256

# Shuffled mini-batch loader over train_data, backed by 30 worker processes.
dataloader_train = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=30,
)

In [None]:
# Model_MAE
import model_mae

# NOTE: the assignment below rebinds the name `model_mae` from the imported
# module to the model instance; the rest of the notebook uses it as the model.
#
# norm_pix_loss is passed as a real boolean. The previous value 'store_true'
# is the name of an argparse action, not a flag value; it only behaved like
# True because a non-empty string is truthy.
# NOTE(review): assumes the factory treats norm_pix_loss as a boolean flag,
# as the reference MAE implementation does - confirm against model_mae.py.
model_mae = model_mae.__dict__['mae_vit_large_patch16'](norm_pix_loss=True)

model_mae.to(device)

In [None]:
import torch.nn as nn

In [None]:
# Linear learning-rate scaling: lr = base_lr * effective_batch_size / 256.
# The training cell wraps the model in torch.nn.DataParallel, which splits each
# DataLoader batch across the GPUs. batch_size is therefore already the global
# batch size and must not be multiplied by the number of GPUs (that factor only
# applies when batch_size is a per-GPU value, e.g. with DistributedDataParallel).
# NOTE(review): multiply by accum_iter here if gradient accumulation is enabled.
lr = 1e-3 * batch_size / 256

In [None]:
# Gradient-accumulation interval; the training loop updates the weights every
# `accum_iter` iterations (1 = every iteration).
accum_iter = 1
# Loss scaler object; the training loop calls it with the loss and the optimizer.
loss_scaler = NativeScaler()
# Parameter groups built by timm's add_weight_decay with a weight decay of 0.05.
param_groups = optim_factory.add_weight_decay(model_mae, 0.05)
optimizer = torch.optim.AdamW(
    param_groups,
    lr=lr,
    betas=(0.9, 0.95),
)

In [None]:
def write_out(df,log_csv_path,**kwargs):
    """Append one log record to ``df`` and rewrite the CSV log file.

    Args:
        df: DataFrame holding the records logged so far (``None`` or a
            completely empty frame on the first call).
        log_csv_path: Destination of the CSV; it is overwritten on every call.
            When it is a filesystem path, missing parent directories are created.
        **kwargs: Column name -> value pairs that make up the new row.

    Returns:
        A new DataFrame with the previous rows followed by the new one and a
        fresh 0..n-1 index (the input frame is not modified).
    """
    new_row = pd.DataFrame([kwargs])
    if df is None or df.shape == (0, 0):
        # Nothing logged yet: skip concatenating with an empty frame.
        combined = new_row
    else:
        # ignore_index avoids every row carrying the duplicate index label 0.
        combined = pd.concat([df, new_row], ignore_index=True)

    # Create the target directory so the write cannot fail on a fresh checkout.
    if isinstance(log_csv_path, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(log_csv_path))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    combined.to_csv(log_csv_path,index=False)
    return combined

In [None]:
# Hyper-parameters of the training run. The same two values are assigned again
# at the top of the training cell, so this cell is redundant with it.
mask_ratio = 0.75  # passed to the model as `mask_ratio` on every forward call
n_epochs = 800  # number of passes of the outer training loop

In [None]:
mask_ratio = 0.75
n_epochs = 800

# Training
train_loss = []  # mean training loss of every finished epoch

log_csv_path = "./models/log.csv"
# Create the log directory up front so the first write cannot fail after a
# whole epoch of training has already been spent.
os.makedirs(os.path.dirname(log_csv_path), exist_ok=True)
df = pd.DataFrame()
start_time = time.time()

model_mae.train()
# Use the notebook-wide GPU list instead of repeating the literal here.
model = torch.nn.DataParallel(model_mae, gpu_ids)

for epoch in range(n_epochs):

    n_train = 0
    total_loss_train = 0.0

    optimizer.zero_grad()  # reset gradients at the start of every epoch, as MAE does

    # Wrap the loader (not the enumerate object) so tqdm knows the total length.
    # The batch is named `images` so the loop does not overwrite the notebook-level
    # variable `x` that holds the preview image used by later cells.
    for data_iter_step, (images, _) in enumerate(tqdm(dataloader_train)):

        n_batch = images.shape[0]

        # The learning-rate schedule is stepped per iteration with a fractional epoch
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(dataloader_train) + epoch)

        images = images.to(device)  # move the batch to the training device

        # Forward pass through MAE under automatic mixed precision
        with torch.cuda.amp.autocast():
            loss, _, _ = model(images, mask_ratio=mask_ratio)

        # DataParallel gathers one loss value per replica. Average them: a sum
        # would inflate the loss and its gradients by the number of GPUs in use.
        loss = loss.mean()
        # Keep the unscaled value for logging, before the accumulation division.
        loss_value = loss.item()

        # loss_scaler performs the backward pass and, when update_grad is True,
        # the optimizer step
        loss = loss / accum_iter
        loss_scaler(loss, optimizer, parameters=model.parameters(),
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        total_loss_train += loss_value * n_batch
        n_train += n_batch

    epoch_loss = total_loss_train / n_train
    train_loss.append(epoch_loss)

#     if epoch % 20 == 0 or epoch + 1 == n_epochs:
#         torch.save(model.module.state_dict(), f"./models/{epoch}.pt")

    print('EPOCH: {}, Train_loss {:.3f}'.format(
        epoch,
        epoch_loss
    ))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

    # Log this epoch's loss only (not the whole history list) as one CSV row.
    df = write_out(df,log_csv_path,epoch=epoch,train_loss=epoch_loss, train_time = total_time_str)

In [None]:
# Visualization

In [None]:
# Load the trained model
import model_mae
# As in the training setup, the assignment rebinds `model_mae` from the module
# to the model instance. norm_pix_loss is a real boolean here: the previous
# value 'store_true' is an argparse action name that only acted as True
# because a non-empty string is truthy.
model_mae = model_mae.__dict__['mae_vit_large_patch16'](norm_pix_loss=True)
# The model is only used for visualization from here on.
model_mae.eval()
# map_location='cpu' lets a checkpoint that was saved from GPU tensors load on
# any machine; the freshly built model above has not been moved to a GPU.
model_mae.load_state_dict(torch.load('models/400.pt', map_location='cpu'))

In [None]:
# Build the demo image explicitly from the dataset instead of reading whatever
# `x` was last bound to: after the training cell has run, `x` is a 4-D batch
# (possibly on the GPU), which can neither be converted with .numpy() nor shown.
# Item 19 is the same sample that the preview cell displays.
img = train_data[19][0].permute(1, 2, 0).detach().cpu().numpy()  # (H, W, 3)
plt.imshow(img)

In [None]:
# define the utils

# Per-channel statistics used by show_image to undo the standardization.
# They must be flat arrays of shape (3,) to broadcast against an [H, W, 3]
# image. Stacking three shape-(1,) tensors with np.array([...]) produced shape
# (3, 1), which cannot broadcast. A scalar or single-channel statistic is
# repeated for all three channels; a three-channel statistic is used as is.
imagenet_mean = np.broadcast_to(np.asarray(mean, dtype=np.float32).reshape(-1), (3,)).copy()
imagenet_std = np.broadcast_to(np.asarray(std, dtype=np.float32).reshape(-1), (3,)).copy()

def show_image(image, title=''):
    """Plot one [H, W, 3] image on the current axes.

    The image is multiplied by ``imagenet_std``, shifted by ``imagenet_mean``,
    scaled by 255, clipped to [0, 255] and cast to integers before plotting.

    Args:
        image: Image whose third axis has size 3.
        title: Text placed above the plot.
    """
    # image is [H, W, 3]
    assert image.shape[2] == 3
    pixels = (image * imagenet_std + imagenet_mean) * 255
    plt.imshow(torch.clip(pixels, 0, 255).int())
    plt.title(title, fontsize=16)
    plt.axis('off')

def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'):
    """Build an MAE model and load checkpoint weights into it.

    Args:
        chkpt_dir: Path of the checkpoint file passed to ``torch.load``.
        arch: Name of the model factory looked up in the ``model_mae`` module.

    Returns:
        The model with the checkpoint weights loaded (non-strict, on CPU).
    """
    # Import under a private alias inside the function: this code used to refer
    # to `models_mae`, a name that is never imported in this notebook, and the
    # global name `model_mae` is rebound to a model instance by other cells.
    import model_mae as _mae_module

    # build model
    model = getattr(_mae_module, arch)()
    # load model
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    # Accept both {'model': state_dict} checkpoints and bare state dicts such as
    # the one loaded directly with load_state_dict elsewhere in this notebook.
    state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
    msg = model.load_state_dict(state_dict, strict=False)
    print(msg)
    return model

def run_one_image(img, model, mask_ratio=0.75):
    """Run the MAE model on one image and plot the result in four panels.

    The panels show the original image, the masked image, the model's
    reconstruction, and the reconstruction pasted over the visible patches.

    Args:
        img: One image laid out as [H, W, C] (array or tensor).
        model: MAE model; it is called as ``model(x, mask_ratio=...)`` and must
            provide ``unpatchify`` and ``patch_embed.patch_size``.
        mask_ratio: Value forwarded to the model as ``mask_ratio``
            (default 0.75, the value that was previously fixed).

    NOTE(review): if the model was trained with norm_pix_loss enabled, its
    output may not be directly comparable to raw pixel values - confirm.
    """
    x = torch.as_tensor(img)

    # make it a batch-like
    x = x.unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)

    # run MAE: inference only, so no autograd graph is built, and the input is
    # moved to whichever device the model's parameters live on
    model_device = next(model.parameters()).device
    with torch.no_grad():
        _, y, mask = model(x.float().to(model_device), mask_ratio=mask_ratio)
        y = model.unpatchify(y)
        y = torch.einsum('nchw->nhwc', y).detach().cpu()

        # visualize the mask
        mask = mask.detach()
        mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0]**2 *3)  # (N, H*W, p*p*3)
        mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
        mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

    # back to channels-last on the CPU, matching y and mask
    x = torch.einsum('nchw->nhwc', x).cpu()

    # masked image
    im_masked = x * (1 - mask)

    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask

    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 24]

    plt.subplot(1, 4, 1)
    show_image(x[0], "original")

    plt.subplot(1, 4, 2)
    show_image(im_masked[0], "masked")

    plt.subplot(1, 4, 3)
    show_image(y[0], "reconstruction")

    plt.subplot(1, 4, 4)
    show_image(im_paste[0], "reconstruction + visible")

    plt.show()

In [None]:
# Seed torch's RNG right before the call - presumably so that the model's
# random masking, and therefore the figure, is repeatable between runs.
torch.manual_seed(2)
print('MAE with pixel reconstruction:')
# Plot original / masked / reconstruction / pasted panels for `img`.
run_one_image(img, model_mae)