In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from vit_3d import ViT
from mae import MAE
import os
from glob import glob
import sys
import yaml
import time
import cv2
import h5py
import random
import logging
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm
from tensorboardX import SummaryWriter
from collections import OrderedDict
import multiprocessing as mp
from sklearn.metrics import f1_score, average_precision_score, roc_auc_score
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
)
import multiprocessing as mp
from einops import rearrange
from monai import transforms
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
# Import distributed-training (DDP) utilities

import torch.distributed
import torch.multiprocessing as mp
import argparse
from attrdict import AttrDict

import waterz
# import evaluate as ev
from skimage.metrics import adapted_rand_error as adapted_rand_ref
from skimage.metrics import variation_of_information as voi_ref
import warnings
warnings.filterwarnings("ignore")

#logging 
# Console logger for this notebook: INFO and above, with a timestamped format.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# logging.getLogger returns the same logger object every time, so re-running this
# cell used to attach one more StreamHandler per execution and every message was
# then printed several times. Reuse the handler that is already attached, if any.
ch = next((h for h in logger.handlers if isinstance(h, logging.StreamHandler)), None)
if ch is None:
    ch = logging.StreamHandler()
    logger.addHandler(ch)
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)

# Run on the first GPU when CUDA is usable; otherwise stay on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Collect every gzipped volume of the raw EM dataset in a stable (sorted) order.
em_FAFB_dir = sorted(glob('/data/cyd0806/EM_raw_image/hdf_data/nii.gz/*.gz'))
# One dict per volume, keyed by "image" (the key the transforms in this notebook use).
em_data_dicts = [
    {"image": image_name}
    for image_name in em_FAFB_dir
]
# Output locations: TensorBoard event files and model checkpoints.
record_dir = '/output/logs'
save_dir = '/data/cyd0806/EM_raw_image/MAE/DecisionMAE_32*160*160'
# exist_ok=True replaces the separate exists()/makedirs() pair: it is idempotent on
# re-runs and cannot fail when another process creates the directory in between.
os.makedirs(record_dir, exist_ok=True)
os.makedirs(save_dir, exist_ok=True)

# Instantiate the TensorBoard writer.
writer = SummaryWriter(log_dir= record_dir)
# Log how many volumes were found.
logger.info(f"em_FAFB_dir length: {len(em_FAFB_dir)}")
# TODO: the original notes also planned to record this count in TensorBoard;
# no such call exists yet.

# Load the pre-training configuration; AttrDict gives attribute-style access
# (e.g. cfg.TRAIN.power, used by calculate_lr).
with open('/data/ydchen/VLP/bigmodel/IJCAI23/MAE/config/pretraining_all.yaml', 'r') as f:
    cfg = AttrDict(yaml.safe_load(f))

# Random apply
class RandomApply(nn.Module):
    """Wrap a callable so that it is applied to the input only some of the time.

    Args:
        fn: callable invoked as ``fn(x)`` when the wrapper decides to apply it.
        p: threshold compared against a fresh ``random.random()`` draw on every
            call; the callable runs unless the draw is greater than ``p``.
    """

    def __init__(self, fn, p=0.3):
        super().__init__()
        self.fn, self.p = fn, p

    def forward(self, x):
        """Return ``fn(x)`` when the random draw does not exceed ``p``, else ``x`` unchanged."""
        skip = random.random() > self.p
        return x if skip else self.fn(x)

# 1*84*2048*2048
# Pre-processing / augmentation pipeline applied to each {"image": path} dict.
# The steps run in list order, so their relative position matters.
train_transforms = Compose(
    [
        # Read the volume from disk and make sure it has a leading channel axis.
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        # ScaleIntensityRanged(
        #     keys=["image"], a_min=100, a_max=255,
        #     b_min=0.0, b_max=1.0, clip=True,
        # ),
        # Crop to the foreground, using the image itself as the foreground source.
        CropForegroundd(keys=["image"], source_key="image"),
        # Re-orient to RAS only some of the time (RandomApply's default p=0.3).
        RandomApply(Orientationd(keys=["image"], axcodes="RAS")),
        # transforms.RandAffined(
        #     keys=['image'],
        #     mode=('bilinear', 'nearest'),
        #     shear_range=(0.5, 0.5, 0.5),
        #     prob=1.0, spatial_size=(128,128,48),
        #     rotate_range=(0, np.pi/15),
        #    ),

        # Random Gaussian noise.
        # NOTE(review): the noise is added BEFORE intensities are rescaled from
        # [0, 255] to [0, 1] below, so std=0.1 is relative to the raw range —
        # confirm this near-invisible noise level is intended.
        transforms.RandGaussianNoised(keys=["image"], prob=0.1, mean=0.0, std=0.1),
        # Draw num_samples=6 random crops of fixed size 160 x 160 x 32 per volume.
        transforms.RandSpatialCropSamplesd(
            keys=["image"],
            roi_size=[160,160,32],
            random_size=False,
            num_samples=6,
        ),
        # Map intensities from [0, 255] to [0, 1], clipping values outside the range.
        transforms.ScaleIntensityRangeD(
            keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True)

        # transforms.ScaleIntensityd(keys=["image"]),
        # user can also add other random transforms

        #transforms.ToTensord(keys=["image"]),
    ]
)

# Encoder: a 3D Vision Transformer for single-channel 32 x 160 x 160 volumes.
# With 16 x 16 x 4 patches this corresponds to 10 x 10 x 8 = 800 patches per volume
# (the same f=8, h=10, w=10 factors are used later when patches are reassembled).
model = ViT(
                        image_size = 160,          # in-plane (height / width) size
                        frames = 32,               # number of frames (depth slices)
                        image_patch_size = 16,     # in-plane patch size
                        frame_patch_size = 4,      # patch size along the frame axis
                        channels=1,                # single-channel input
                        num_classes = 1000,        # NOTE(review): classification head size; presumably unused for MAE pre-training — confirm
                        dim = 768,
                        depth = 12,
                        heads = 12,
                        mlp_dim = 3072,
                        dropout = 0.1,
                        emb_dropout = 0.1
                    )

# Masked auto-encoder wrapped around the ViT encoder above.
mae = MAE(
    encoder = model,
    masking_ratio = 0.5,   # 50% of the patches are masked here (the MAE paper recommends 75%)
    decoder_dim = 512,      # paper showed good results with just 512
    decoder_depth = 6,       # anywhere from 1 to 8
    hog = False             # NOTE(review): presumably toggles a HOG reconstruction target (see the `mae.HOG` check later) — confirm
)
# Optimizer setup: base learning rate, its schedule, and the optimizer itself.
lr = 1e-5


def calculate_lr(iters):
    """Return the learning rate for training iteration ``iters``.

    Below iteration 3000 the rate follows a polynomial curve built from the
    module-level ``lr`` together with ``cfg.TRAIN.power`` and ``cfg.TRAIN.end_lr``;
    from iteration 3000 onwards it is held at ``lr / 10``.

    NOTE(review): for ``iters < 300`` the polynomial base is greater than 1, so
    (for a positive power) the result exceeds ``0.9 * lr + end_lr`` — confirm the
    300-iteration offset is intended.
    """
    if not iters < 3000:
        return lr/10
    decay = pow(1 - float(iters - 300) / 3000, cfg.TRAIN.power)
    return (lr - lr/10) * decay + cfg.TRAIN.end_lr


optimizer = torch.optim.Adam(
    mae.parameters(),
    lr=lr,
    betas=(0.9, 0.999),
    weight_decay=0.00001,
    amsgrad=True,
)
# Alternative kept for reference: SGD with momentum.
#optimizer = torch.optim.SGD(mae.parameters(), lr=lr,momentum=0.9,weight_decay=0.00001)


In [None]:
# Checkpoint to restore; it lives in the same directory as `save_dir`.
mae_weight = '/data/cyd0806/EM_raw_image/MAE/DecisionMAE_32*160*160/mae_full_mae_32*160*160_760.pth'
# map_location='cpu': without it torch.load restores tensors onto the device they
# were saved from and fails when that GPU is not available on this machine.
# load_state_dict copies the values into the model's own parameters afterwards.
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
# strict=False tolerates missing / unexpected keys; the returned report is kept as
# the cell's last expression so the notebook shows which keys did not match.
mae.load_state_dict(torch.load(mae_weight, map_location='cpu'),strict = False)

In [None]:
# Fetch a single batch from the loader for interactive inspection.
# NOTE(review): `train_loader` is not defined anywhere in the visible part of this
# notebook — it has to be created before this cell can run.
# Keep the CPU fallback instead of hard-coding 'cuda:0', which made the `.to(device)`
# below fail on machines without a usable GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
for i,batch_data in enumerate(train_loader):
    # Move the last spatial axis in front of the other two: (b, c, x, y, z) -> (b, c, z, x, y).
    # Presumably z is the 32-slice depth axis the ViT expects as `frames` — TODO confirm.
    batch_data['image'] = rearrange(batch_data['image'], 'b c x y z -> b c z x y')
    images = batch_data['image'].to(device)
    break  # only the first batch is needed

In [None]:
# Step-by-step re-implementation of the MAE forward pass with shape printouts, run on
# a random tensor (`img`) rather than on the `images` batch loaded from the loader.
# NOTE(review): the device is hard-coded to 'cuda:0' here, overriding the CPU
# fallback chosen at the top of the notebook.
# NOTE(review): the model is never switched to eval mode in this notebook, so the
# dropout configured on the encoder is presumably active — confirm this is intended.
from einops import rearrange,repeat
device = torch.device('cuda:0')
img = torch.randn(4,1,32,160,160).to(device)
# get patches
mae = mae.to(device)
img = img.to(device)
patches = mae.to_patch(img)
print(f'patches shape: {patches.shape}')
batch, num_patches, *_ = patches.shape

# patch to encoder tokens and add positions
# (position 0 of pos_embedding is skipped by the 1: slice — presumably the class-token slot)

tokens = mae.patch_to_emb(patches)
print(f'tokens shape: {tokens.shape}')
tokens = tokens + mae.encoder.pos_embedding[:, 1:(num_patches + 1)]
print(f'tokens shape: {tokens.shape}')
# calculate of patches needed to be masked, and get random indices, dividing it up for mask vs unmasked

num_masked = int(mae.masking_ratio * num_patches)
# Alternative index source (disabled): a learned decision network instead of uniform noise.
#rand_indices = mae.decision_net(patches).argsort(dim = -1)
# argsort of uniform noise yields an independent random permutation of patch indices per sample.
rand_indices = torch.rand(batch, num_patches, device = device).argsort(dim = -1)
print(f'rand indices shape: {rand_indices.shape}')
# The first num_masked entries of each permutation are masked, the remainder stay visible.
masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]
print(f'masked indices shape: {masked_indices.shape}, unmasked indices shape: {unmasked_indices.shape}')
# get the unmasked tokens to be encoded

# Column vector of batch indices; it broadcasts against the (batch, k) index tensors
# so that each sample is gathered with its own indices.
batch_range = torch.arange(batch, device = device)[:, None]
print(f'batch range shape: {batch_range.shape}')
print(f'batch range: {batch_range}')
tokens = tokens[batch_range, unmasked_indices]
print(f'tokens shape: {tokens.shape}')
# get the patches to be masked for the final reconstruction loss

masked_patches = patches[batch_range, masked_indices]
print(f'masked patches shape: {masked_patches.shape}')
# Optional HOG target: replace each masked patch by the HOG visualisation of its
# slice at depth index 2, repeated along the depth axis.
# NOTE(review): `hog` is not imported anywhere in the visible part of this notebook,
# so this branch raises NameError if `mae.HOG` is truthy.
if mae.HOG:
    masked_patches = rearrange(masked_patches.cpu().numpy(), 'b p (h w d) -> b p h w d', h = mae.encoder.image_patch_size, w = mae.encoder.image_patch_size,d = mae.encoder.frame_patch_size)
    for b in range(len(masked_patches)):
        for p in range(len(masked_patches[0])):
            _, temp_hog = hog(masked_patches[b,p,:,:,2], orientations=8, pixels_per_cell=(4, 4),
                    cells_per_block=(1, 1), visualize=True, multichannel=False)
            masked_patches[b,p] = repeat(temp_hog, 'h w -> h w d', d = mae.encoder.frame_patch_size)
    masked_patches = torch.tensor(rearrange(masked_patches, 'b p h w d -> b p (h w d)')).to(device)
# attend with vision transformer

encoded_tokens = mae.encoder.transformer(tokens)

# project encoder to decoder dimensions, if they are not equal - the paper says you can get away with a smaller dimension for decoder

decoder_tokens = mae.enc_to_dec(encoded_tokens)

# reapply decoder position embedding to unmasked tokens

unmasked_decoder_tokens = decoder_tokens + mae.decoder_pos_emb(unmasked_indices)

# repeat mask tokens for number of masked, and add the positions using the masked indices derived above

mask_tokens = repeat(mae.mask_token, 'd -> b n d', b = batch, n = num_masked)
mask_tokens = mask_tokens + mae.decoder_pos_emb(masked_indices)

# concat the masked tokens to the decoder tokens and attend with decoder
# (both sets are scattered back to their original patch positions in one full-length buffer)

decoder_tokens = torch.zeros(batch, num_patches, mae.decoder_dim, device=device)
decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
decoder_tokens[batch_range, masked_indices] = mask_tokens
decoded_tokens = mae.decoder(decoder_tokens)
print(f'decoded tokens shape: {decoded_tokens.shape}')

# splice out the mask tokens and project to pixel values

mask_tokens = decoded_tokens[batch_range, masked_indices]
pred_pixel_values = mae.to_pixels(mask_tokens)
print(f'pred pixel values shape: {pred_pixel_values.shape}')
# NOTE(review): the sigmoid-squashed prediction is computed but not used by the loss below.
pred_pixel_values_sigmoid = torch.sigmoid(pred_pixel_values)

# calculate reconstruction loss
# (mean squared error between the raw predictions and the masked target patches)

recon_loss = F.mse_loss(pred_pixel_values, masked_patches)

In [None]:
# Add Gaussian blur to images
import cv2
import numpy as np
import matplotlib.pyplot as plt
import random
import math

def gaussian_blur(img, kernel_size, sigma):
    """Blur ``img`` with a square ``kernel_size`` x ``kernel_size`` Gaussian kernel.

    ``sigma`` is forwarded to ``cv2.GaussianBlur`` as its third positional argument.
    Returns the blurred image produced by OpenCV.
    """
    window = (kernel_size, kernel_size)
    return cv2.GaussianBlur(img, window, sigma)

def random_gaussian_blur(img):
    """Blur ``img`` with randomly chosen parameters, then add Gaussian noise.

    The kernel size is drawn from {5, 7, 9} and sigma from {1, 2, 3}. Standard-normal
    noise of the same shape as the blurred image, scaled by a factor drawn from
    {0.02, 0.03, 0.05}, is added to the result before it is returned.
    """
    ksize = random.choice([5,7,9])
    sigma = random.choice([1, 2, 3])
    blurred = gaussian_blur(img, ksize, sigma)
    noise = np.random.randn(*blurred.shape)
    return blurred + noise * random.choice([0.02, 0.03, 0.05])

def random_rotate(img):
    """Rotate ``img`` by a randomly chosen multiple of 90 degrees (0, 90, 180 or 270).

    The rotation is performed with ``np.rot90`` and its result is returned.
    """
    quarter_turns = random.choice([0, 90, 180, 270]) // 90
    return np.rot90(img, quarter_turns)

def random_crop(img, crop_size):
    """Return a random ``crop_size`` x ``crop_size`` window of ``img``.

    The window is taken over the first two axes; its top-left corner is drawn
    uniformly from all positions that keep the window inside the image.
    """
    height, width = img.shape[:2]
    top = random.randint(0, height - crop_size)
    left = random.randint(0, width - crop_size)
    return img[top:top + crop_size, left:left + crop_size]

def random_flip(img):
    """Flip ``img`` along one randomly chosen axis and return the result.

    The axis is drawn from 0, 1 and 2, restricted to the axes ``img`` actually has.
    Previously axis 2 could be chosen for a 2-D image, which made ``np.flip`` raise
    an axis error on roughly one call in three. Inputs with three or more
    dimensions behave exactly as before.
    """
    # Random flip: only offer axes that exist on this input.
    candidate_axes = [axis for axis in (0, 1, 2) if axis < np.ndim(img)]
    if not candidate_axes:
        # A 0-d input has no axis to flip; hand it back unchanged.
        return img
    flip = random.choice(candidate_axes)
    img = np.flip(img, flip)
    return img

In [None]:
# Replace the masked target patches of sample 0 with blurred + noisy versions.
# Unflatten each patch to (h, w, d) on the CPU as a NumPy array so OpenCV can process it.
# NOTE(review): this cell expects the flattened (b, p, h*w*d) layout; it cannot be
# re-run once a later cell has unflattened `masked_patches`.
masked_patches = rearrange(masked_patches.cpu().numpy(), 'b p (h w d) -> b p h w d', h = mae.encoder.image_patch_size, w = mae.encoder.image_patch_size,d = mae.encoder.frame_patch_size)
for i in range(masked_patches.shape[1]):
    for j in range(masked_patches.shape[2]):
        # NOTE(review): j runs over the h axis, so each slice passed to the blur has
        # shape (w, d), not (h, w) — confirm a per-depth slice
        # (masked_patches[0, i, :, :, j] with j over shape[4]) was not intended.
        # NOTE(review): only batch element 0 is modified; other samples stay as they were.
        masked_patches[0,i,j,:,:] = random_gaussian_blur(masked_patches[0,i,j,:,:])
# Flatten back to (b, p, h*w*d) and return to the compute device.
masked_patches = torch.tensor(rearrange(masked_patches, 'b p h w d -> b p (h w d)')).to(device)

In [None]:
# Rebuild a full volume in which visible patches keep their pixels and masked
# patches are blank (zero).
recons_tokens = torch.zeros((batch, num_patches, 1024), device=device)
recons_tokens[batch_range, masked_indices] = 0
recons_tokens[batch_range, unmasked_indices] = patches[batch_range, unmasked_indices]
# Fold the 8 x 10 x 10 grid of 16 x 16 x 4 patches back into a (b, c, 32, 160, 160) volume.
patches2 = rearrange(
    recons_tokens,
    'b (f h w) (p1 p2 pf c) -> b c (f pf) (h p1) (w p2)',
    f=8, h=10, w=10, p1=16, p2=16, pf=4,
)

In [None]:
# Inspect the shape of the reassembled token buffer (allocated as batch x num_patches x 1024).
recons_tokens.shape

In [None]:
# Show depth slice 0 of sample 0, channel 0 of the masked volume.
plt.imshow(patches2[0, 0, 0].cpu().numpy(), cmap='gray')

In [None]:
# Plot every 6th depth slice (0, 6, ..., 30) of the raw input volume in one column
# and save the figure.
os.makedirs('./visual', exist_ok=True)  # savefig does not create missing directories
plt.figure(figsize=(20,10))
for i in range(0,32,6):
    # The subplot position must be an integer: i/6+1 is a float (1.0, 2.0, ...),
    # which current matplotlib rejects, so use floor division.
    plt.subplot(6,1,i//6+1)
    plt.imshow(img[0,0,i].cpu(),cmap='gray')
    plt.axis('off')
# One layout pass after all panels exist gives the same final result as one per panel.
plt.tight_layout()
plt.savefig('./visual/img_raw_0124.png',dpi=300,bbox_inches='tight')

In [None]:
# Plot every 6th depth slice (0, 6, ..., 30) of the masked volume in one column
# and save the figure.
os.makedirs('./visual', exist_ok=True)  # savefig does not create missing directories
plt.figure(figsize=(20,10))
for i in range(0,32,6):
    # The subplot position must be an integer: i/6+1 is a float (1.0, 2.0, ...),
    # which current matplotlib rejects, so use floor division.
    plt.subplot(6,1,i//6+1)
    plt.imshow(patches2[0,0,i].cpu(),cmap='gray')
    plt.axis('off')
# One layout pass after all panels exist gives the same final result as one per panel.
plt.tight_layout()
plt.savefig('./visual/img_random_0124.png',dpi=300,bbox_inches='tight')

In [None]:
# Save one depth slice of the masked volume as a borderless image.
# 10 20 30  (presumably the slice indices that were exported by editing the index below)
os.makedirs('visual', exist_ok=True)  # savefig does not create missing directories
plt.imshow(patches2[0,0,30].cpu().detach().numpy(),cmap='gray')
plt.axis('off')
plt.savefig('visual/ours_MAE1.png',bbox_inches='tight',pad_inches=0,dpi=480)

In [None]:
# Save depth slice 20 of the raw input volume as a borderless image.
os.makedirs('visual', exist_ok=True)  # savefig does not create missing directories
plt.imshow(img[0,0,20].cpu().numpy(),cmap='gray')
plt.axis('off')
plt.savefig('visual/raw_image3.png',bbox_inches='tight',pad_inches=0,dpi=480)

In [None]:
# Preview the blur + noise augmentation on depth slice 30 of the current volume.
gauss = random_gaussian_blur(patches2[0, 0, 30].detach().cpu().numpy())
plt.imshow(gauss, cmap='gray')

In [None]:
# Rebuild the volume again, this time filling the masked positions with the current
# `masked_patches` (which must be in the flattened b x p x 1024 layout) instead of zeros.
recons_tokens = torch.zeros((batch, num_patches, 1024), device=device)
recons_tokens[batch_range, masked_indices] = masked_patches
recons_tokens[batch_range, unmasked_indices] = patches[batch_range, unmasked_indices]
# Fold the 8 x 10 x 10 grid of 16 x 16 x 4 patches back into a (b, c, 32, 160, 160) volume.
patches2 = rearrange(
    recons_tokens,
    'b (f h w) (p1 p2 pf c) -> b c (f pf) (h p1) (w p2)',
    f=8, h=10, w=10, p1=16, p2=16, pf=4,
)

In [None]:
# Show depth slice 30 of sample 0, channel 0 of the rebuilt volume.
plt.imshow(patches2[0, 0, 30].detach().cpu().numpy(), cmap='gray')

In [None]:
# Unflatten predictions and targets from (b, p, h*w*d) to (b, p, h, w, d) so that
# single patch slices can be viewed as images. Both tensors must still be in the
# flattened layout, so this cell is not re-runnable.
pred_pixel_values, masked_patches = (
    rearrange(
        flat_patches,
        'b p (h w d) -> b p h w d',
        h=mae.encoder.image_patch_size,
        w=mae.encoder.image_patch_size,
        d=mae.encoder.frame_patch_size,
    )
    for flat_patches in (pred_pixel_values, masked_patches)
)

In [None]:
# Side by side: predicted (left) and target (right) pixels of the first masked
# patch of sample 0, at depth index 1 within the patch.
plt.subplot(1, 2, 1)
plt.imshow(pred_pixel_values[0, 0, ..., 1].detach().cpu().numpy(), cmap='gray')
plt.subplot(1, 2, 2)
plt.imshow(masked_patches[0, 0, ..., 1].detach().cpu().numpy(), cmap='gray')

In [None]:
# Inspect the per-sample random patch permutation that was split into masked / unmasked indices.
rand_indices

In [57]:
!nvidia-smi

Sun May 28 07:25:54 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.63.01    Driver Version: 470.63.01    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A40          Off  | 00000000:4F:00.0 Off |                    0 |
|  0%   29C    P0    70W / 300W |  30089MiB / 45634MiB |      2%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A40          Off  | 00000000:52:00.0 Off |                    0 |
|  0%   31C    P0    81W / 300W |  44447MiB / 45634MiB |    100%      Default |
|       

In [None]:
# Scratch cell: chained assignment binds a, b and c to the same value 6.
# NOTE(review): this rebinds `b`, which earlier cells use as a loop variable.
a=b=c=6

In [None]:
# Scratch cell: display a and b as a tuple.
a,b

In [None]:
# Scratch cell: display c.
c