In [1]:
import sys
import os
import requests

import torch
import numpy as np

import matplotlib.pyplot as plt

from PIL import Image


sys.path.append('..')
from mae import models_mae

In [None]:
# Normalization statistics that show_image uses to undo normalization before
# plotting. Each array holds a single value, so it broadcasts over all channels.
# NOTE(review): despite the imagenet_ prefix these look like single-channel
# dataset statistics rather than the usual 3-channel ImageNet values - confirm.
imagenet_mean, imagenet_std = np.array([0.49377912]), np.array([0.23812352])
# Placeholder only: none of the visible cells assigns or loads a checkpoint.
checkpoint = None
def show_image(image, title=''):
    """Draw a normalized image on the current matplotlib axes.

    The image is de-normalized with the module-level ``imagenet_std`` and
    ``imagenet_mean``, scaled to the 0-255 range, clipped, cast to int and
    passed to ``plt.imshow``. Axis ticks are hidden.

    Args:
        image: image in [H, W, 3] layout; the channel count is asserted.
            Expected to be a torch tensor, since ``torch.clip`` is applied.
        title: caption drawn above the image.
    """
    assert image.shape[2] == 3
    denormalized = image * imagenet_std + imagenet_mean
    pixels = torch.clip(denormalized * 255, 0, 255).int()
    plt.imshow(pixels)
    plt.title(title, fontsize=16)
    plt.axis('off')

def prepare_model(arch='mae_medmnist_chestmnist'):
    """Instantiate an MAE model by the name of its factory in ``models_mae``.

    Args:
        arch: attribute name looked up on the ``models_mae`` module; the
            attribute is called with no arguments to build the model.

    Returns:
        The freshly constructed model. No checkpoint weights are loaded here.
    """
    model_factory = getattr(models_mae, arch)
    return model_factory()

def run_one_image(img, model):
    """Run the MAE model on one image and plot four views side by side.

    The panels are: the input, the input with masked patches zeroed out, the
    raw reconstruction, and the reconstruction pasted over the visible patches.

    Args:
        img: a single image in [H, W, C] layout, convertible by ``torch.tensor``.
        model: MAE model that is callable as ``model(x, mask_ratio=...)``
            returning ``(loss, pred, mask)`` and that exposes ``unpatchify``
            and ``patch_embed.patch_size``.
    """
    # Add a leading batch axis and switch to channel-first layout.
    batch = torch.einsum('nhwc->nchw', torch.tensor(img).unsqueeze(dim=0))

    # Forward pass with 75% of the patches masked; the loss is not used here.
    _, prediction, mask = model(batch.float(), mask_ratio=0.75)
    reconstruction = torch.einsum('nchw->nhwc', model.unpatchify(prediction)).detach().cpu()

    # Expand the per-patch mask to one value per patch pixel and channel,
    # (N, H*W, p*p*3), then fold it back into image layout.
    values_per_patch = model.patch_embed.patch_size[0]**2 * 3
    mask = mask.detach().unsqueeze(-1).repeat(1, 1, values_per_patch)
    mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

    original = torch.einsum('nchw->nhwc', batch)

    # Input with the masked patches zeroed out.
    im_masked = original * (1 - mask)

    # Reconstruction for the masked patches, original pixels for the visible ones.
    im_paste = original * (1 - mask) + reconstruction * mask

    # Enlarge the default figure so the four panels are readable.
    plt.rcParams['figure.figsize'] = [24, 24]

    panels = (
        (original, "original"),
        (im_masked, "masked"),
        (reconstruction, "reconstruction"),
        (im_paste, "reconstruction + visible"),
    )
    for position, (images, caption) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        show_image(images[0], caption)

    plt.show()



### Load a MAE model


In [None]:
# Build the MAE model to pre-train; any other model name accepted by
# prepare_model can be substituted here.
model_mae = prepare_model(arch='mae_vit_large_patch16')
print('Model loaded.')

## 准备数据集

In [None]:
from Chestmnist_Dataset import compute_mean_std
from Chestmnist_Dataset import ChestMNISTDataset
from torch.utils.data import DataLoader
from torchvision import transforms

# Location of the ChestMNIST archive. Set the CHESTMNIST_NPZ environment
# variable to point somewhere else without editing the notebook; the fallback
# is the path this notebook was originally written against.
npz_path = os.environ.get("CHESTMNIST_NPZ", "C:/Users/15899/.medmnist/chestmnist.npz")

# Compute the mean and standard deviation used for normalization.
mean, std = compute_mean_std(npz_path)
print("Mean:", mean)
print("Std:", std)

# Convert each sample to a tensor, then normalize with the statistics above.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

# One dataset per split, all sharing the same transform.
train_dataset = ChestMNISTDataset(npz_path, split='train', transform=transform)
val_dataset = ChestMNISTDataset(npz_path, split='val', transform=transform)
test_dataset = ChestMNISTDataset(npz_path, split='test', transform=transform)

batch_size = 32
# Number of dataloader workers, capped at 4. os.cpu_count() may return None
# when the count cannot be determined, so fall back to 1 instead of letting
# min() fail on a None/int comparison.
nw = min(os.cpu_count() or 1, batch_size if batch_size > 1 else 1, 4)
print('Using {} dataloader workers every process'.format(nw))

# Dataloaders: only the training split is shuffled.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=nw,
                          collate_fn=train_dataset.collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=nw,
                        collate_fn=val_dataset.collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=nw,
                         collate_fn=test_dataset.collate_fn)

## 准备预训练

In [None]:
from torchsummary import summary

# Print the per-layer structure of the model.
# torchsummary prepends the batch dimension itself, so input_size must be the
# per-sample shape (C, H, W); the batch size is passed separately and only
# affects the printed shapes. The device is taken from the model so the dummy
# input lands on the same device as the weights (torchsummary defaults to
# "cuda" whenever CUDA is available).
summary(
    model_mae,
    (3, 224, 224),
    batch_size=8,
    device=next(model_mae.parameters()).device.type,
)

In [None]:
import math

base_learning_rate = 1.5e-4
weight_decay = 5e-2
warmup_epoch = 100
total_epoch = 1600

# AdamW applies decoupled weight decay. Plain Adam would instead add
# weight_decay * w to the gradient (L2 regularisation), which interacts with
# the adaptive step size and is not what these hyper-parameters are tuned for.
# The learning rate is scaled linearly with the batch size relative to 256.
optim = torch.optim.AdamW(
    model_mae.parameters(),
    lr=base_learning_rate * batch_size / 256,
    betas=(0.9, 0.95),
    weight_decay=weight_decay,
)


def lr_func(epoch):
    """Return the learning-rate multiplier for a given epoch.

    The multiplier is the smaller of a linear warm-up ramp, which reaches 1
    after ``warmup_epoch`` epochs, and a half-cosine decay over
    ``total_epoch`` epochs.
    """
    warmup_factor = (epoch + 1) / (warmup_epoch + 1e-8)
    cosine_factor = 0.5 * (math.cos(epoch / total_epoch * math.pi) + 1)
    return min(warmup_factor, cosine_factor)


# `verbose` is deprecated for LR schedulers in recent PyTorch releases; use
# lr_scheduler.get_last_lr() to inspect the current learning rate instead.
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lr_func)

In [None]:
from tqdm.auto import tqdm  # tqdm_notebook is deprecated; auto picks the notebook widget when available

model_path = 'chestmnist_mae_checkpoint.pth'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Batches are moved to `device` below, so the model has to live there as well.
# Module.to() moves the parameters in place by default, so an optimizer that
# was built from these parameters beforehand keeps tracking them.
model_mae.to(device)
mask_ratio = 0.75
train_loss = []
val_loss = []
test_loss = []  # not populated by this loop
best_loss = float('inf')  # any first validation loss counts as an improvement

for e in range(total_epoch):
    # ---- training pass ----
    model_mae.train()
    losses = []
    train_step = len(train_loader)
    with tqdm(total=train_step, desc=f'Epoch {e+1}/{total_epoch}', mininterval=0.3) as pbar:
        for imgs, _ in train_loader:
            imgs = imgs.to(device)
            loss, predicted_img, mask = model_mae(imgs.float(), mask_ratio=mask_ratio)
            optim.zero_grad()
            loss.backward()
            optim.step()
            losses.append(loss.item())

            pbar.set_postfix(**{'Train Loss': np.mean(losses)})
            pbar.update(1)

    # The schedule is defined per epoch, so step it once per epoch.
    lr_scheduler.step()
    avg_loss = sum(losses) / len(losses)
    train_loss.append(avg_loss)
    print(f"In Epoch {e},average training loss is {avg_loss:.4f}")

    # ---- validation pass (no gradients, no parameter updates) ----
    model_mae.eval()
    val_step = len(val_loader)
    losses = []
    with torch.no_grad():
        with tqdm(total=val_step, desc=f'Epoch {e+1}/{total_epoch}', mininterval=0.3) as pbar:
            for imgs, _ in val_loader:
                imgs = imgs.to(device)
                loss, predicted_val_img, mask = model_mae(imgs.float(), mask_ratio=mask_ratio)
                losses.append(loss.item())

                pbar.set_postfix(**{'Val Loss': np.mean(losses)})
                pbar.update(1)
    avg_loss = sum(losses) / len(losses)
    # Record in val_loss so the training history keeps one entry per epoch.
    val_loss.append(avg_loss)
    print(f"In Epoch {e},average val loss is {avg_loss:.4f}")

    # Keep only the weights with the lowest validation loss seen so far.
    if avg_loss < best_loss:
        best_loss = avg_loss
        print(f"In Epoch {e},the best val loss is {best_loss:.4f},saving the model")
        torch.save(model_mae.state_dict(), model_path)



In [None]:
# Persist the recorded loss histories to disk.
import pickle

data = dict(train_loss=train_loss, val_loss=val_loss, test_loss=test_loss)
with open('./pretrain_loss.pkl', 'wb') as f:
    f.write(pickle.dumps(data))

In [None]:
import pickle
import matplotlib.pyplot as plt

# Load the saved loss histories. The file has to be opened for reading ('rb');
# opening it with 'wb' truncates the saved data and makes pickle.load fail.
# NOTE: pickle.load can execute arbitrary code; only load files you created.
with open('./pretrain_loss.pkl', 'rb') as f:
    data = pickle.load(f)

train_loss = data['train_loss']
val_loss = data['val_loss']
test_loss = data['test_loss']

# Plot style: the bundled seaborn style is called 'seaborn-v0_8' in newer
# matplotlib releases and 'seaborn' in older ones; use whichever is installed.
for style_name in ('seaborn-v0_8', 'seaborn'):
    if style_name in plt.style.available:
        plt.style.use(style_name)
        break

# A single axes holding all loss curves.
fig, ax1 = plt.subplots(1, 1, figsize=(10, 8))

# Each curve gets its own x axis so that histories of different lengths can
# still be drawn; empty histories are skipped.
curves = (
    (train_loss, 'b', 'Training loss'),
    (val_loss, 'r', 'Validation loss'),
    (test_loss, 'g', 'Test loss'),
)
for series, color, label in curves:
    if len(series) == 0:
        continue
    ax1.plot(range(1, len(series) + 1), series, color, label=label)

ax1.set_title('Training and validation loss')
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Loss')
ax1.legend()

# Tighten the layout so labels do not overlap.
plt.tight_layout()

plt.show()
