In [16]:
from omegaconf import OmegaConf

In [38]:
config_path = '/home/leehu/project/brain2image/brain2image/pretrains/ldm/label2img/config.yaml'

config = OmegaConf.load(config_path)

In [18]:
config.model.params.first_stage_config.params.ddconfig.ch_mult 

[1, 2, 4]

In [1]:
import torch

ckpt = torch.load('/NFS/Users/hwlee/checkpoint.pth')

RuntimeError: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 23.70 GiB total capacity; 1.97 GiB already allocated; 3.56 MiB free; 1.97 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [11]:
ckpt['config'].data_dir

'/DATA1/NSD/data_ae'

In [1]:
import torch

ckpt = torch.load('/home/leehu/project/brain2image/brain2image/pretrains/ldm/label2img/model.ckpt')

In [40]:
ckpt['state_dict']['model.diffusion_model.time_embed.0.weight'].shape

torch.Size([1024, 256])

In [41]:
import importlib

def instantiate_from_config(config):
    if not "target" in config:
        if config == '__is_first_stage__':
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)

In [42]:
config.model.params.unet_config.params.use_time_cond = True # true
config.model.params.unet_config.params.global_pool = False # false

In [43]:
model = instantiate_from_config(config.model)

LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 708.32 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 3, 64, 64) = 12288 dimensions.
making attention of type 'vanilla' with 512 in_channels


In [44]:
model

LatentDiffusion(
  (model): DiffusionWrapper(
    (diffusion_model): UNetModel(
      (time_embed): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): SiLU()
        (2): Linear(in_features=1024, out_features=1024, bias=True)
      )
      (time_embed_condtion): Sequential(
        (0): Conv1d(77, 38, kernel_size=(1,), stride=(1,))
        (1): Conv1d(38, 1, kernel_size=(1,), stride=(1,))
        (2): Linear(in_features=512, out_features=1024, bias=True)
      )
      (input_blocks): ModuleList(
        (0): TimestepEmbedSequential(
          (0): Conv2d(3, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (1): TimestepEmbedSequential(
          (0): ResBlock(
            (in_layers): Sequential(
              (0): GroupNorm32(32, 256, eps=1e-05, affine=True)
              (1): SiLU()
              (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            )
            (h_upd): Identity()
      

In [1]:

import os, sys
import numpy as np
import torch
from eval_metrics import get_similarity_metric
from dataloader import NSD_Dataset
from ldm.ldm_for_fmri import fLDM
from einops import rearrange
from PIL import Image
import torchvision.transforms as transforms
from config import *
import wandb
import datetime
import argparse
import random

def to_image(img):
    if img.shape[-1] != 3:
        img = rearrange(img, 'c h w -> h w c')
    img = 255. * img
    return Image.fromarray(img.astype(np.uint8))

def channel_last(img):
    if img.shape[-1] == 3:
        return img
    return rearrange(img, 'c h w -> h w c')

def normalize(img):
    if img.shape[-1] == 3:
        img = rearrange(img, 'h w c -> c h w')
    img = torch.tensor(img)
    img = img * 2.0 - 1.0 # to -1 ~ 1
    return img

def wandb_init(config):
    wandb.init( project="mind-vis",
                group='eval',
                anonymous="allow",
                config=config,
                reinit=True)

def get_eval_metric(samples, avg=True):
    metric_list = ['mse', 'pcc', 'ssim', 'psm']
    res_list = []
    
    gt_images = [img[0] for img in samples]
    gt_images = rearrange(np.stack(gt_images), 'n c h w -> n h w c')
    samples_to_run = np.arange(1, len(samples[0])) if avg else [1]
    for m in metric_list:
        res_part = []
        for s in samples_to_run:
            pred_images = [img[s] for img in samples]
            pred_images = rearrange(np.stack(pred_images), 'n c h w -> n h w c')
            res = get_similarity_metric(pred_images, gt_images, method='pair-wise', metric_name=m)
            res_part.append(np.mean(res))
        res_list.append(np.mean(res_part))     
    res_part = []
    for s in samples_to_run:
        pred_images = [img[s] for img in samples]
        pred_images = rearrange(np.stack(pred_images), 'n c h w -> n h w c')
        res = get_similarity_metric(pred_images, gt_images, 'class', None, 
                        n_way=50, num_trials=1000, top_k=1, device='cuda')
        res_part.append(np.mean(res))
    res_list.append(np.mean(res_part))
    res_list.append(np.max(res_part))
    metric_list.append('top-1-class')
    metric_list.append('top-1-class (max)')
    return res_list, metric_list

def get_args_parser():
    parser = argparse.ArgumentParser('Double Conditioning LDM Finetuning', add_help=False)
    # project parameters
    parser.add_argument('--root', type=str, default='.')
    parser.add_argument('--dataset', type=str, default='GOD')

    return parser

In [None]:
args = get_args_parser()
args = args.parse_args()
root = args.root
target = args.dataset
#model_path = os.path.join(root, 'pretrains', f'{target}', 'finetuned.pth')
model_path = os.path.join('/home/leehu/project/brain2image/brain2image/results/generation/09/checkpoint.pth')
sd = torch.load(model_path, map_location='cpu')
config = sd['config']

seed = config.seed

# Python의 random 시드 설정
random.seed(seed)

# Numpy 시드 설정
np.random.seed(seed)

# PyTorch 시드 설정
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# update paths
config.root_path = root
# config.kam_path = os.path.join(root, 'data/Kamitani/npz')
# config.bold5000_path = os.path.join(root, 'data/BOLD5000')
config.pretrain_mbm_path = os.path.join('/NFS/Users/hwlee/checkpoint.pth')
config.pretrain_gm_path = os.path.join(root, 'pretrains/ldm/label2img')
print(config.__dict__)

output_path = os.path.join(config.root_path, 'results', 'eval',  
                '%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S")))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

img_transform_test = transforms.Compose([
    normalize, transforms.Resize((256, 256)), 
    channel_last
])

# if target == 'GOD':
#     _, dataset_test = create_Kamitani_dataset(config.kam_path, config.roi, config.patch_size, 
#             fmri_transform=torch.FloatTensor, image_transform=img_transform_test, 
#             subjects=config.kam_subs, test_category=config.test_category)
# elif target == 'BOLD5000':
#     _, dataset_test = create_BOLD5000_dataset(config.bold5000_path, config.patch_size, 
#             fmri_transform=torch.FloatTensor, image_transform=img_transform_test, 
#             subjects=config.bold5000_subs)
# else:
#     raise NotImplementedError

dataset_test = NSD_Dataset(data_dir=config.data_dir, subject_num=config.subject_num, type='val')

num_voxels = dataset_test.num_voxels
print(len(dataset_test))
# prepare pretrained mae 
pretrain_mbm_metafile = torch.load(config.pretrain_mbm_path, map_location='cpu')
# create generateive model
generative_model = fLDM(pretrain_mbm_metafile, num_voxels,
            device=device, pretrain_root=config.pretrain_gm_path, logger=config.logger,
            ddim_steps=config.ddim_steps, global_pool=config.global_pool, use_time_cond=config.use_time_cond)
generative_model.model.load_state_dict(sd['model_state_dict'])
print('load ldm successfully')
state = sd['state']
grid, samples = generative_model.generate(dataset_test, config.num_samples, 
            config.ddim_steps, config.HW, limit=None, state=state) # generate 10 instances
grid_imgs = Image.fromarray(grid.astype(np.uint8))

os.makedirs(output_path, exist_ok=True)
grid_imgs.save(os.path.join(output_path,f'./samples_test.png'))

wandb_init(config)
wandb.log({f'summary/samples_test': wandb.Image(grid_imgs)})
metric, metric_list = get_eval_metric(samples, avg=True)
metric_dict = {f'summary/pair-wise_{k}':v for k, v in zip(metric_list[:-2], metric[:-2])}
metric_dict[f'summary/{metric_list[-2]}'] = metric[-2]
metric_dict[f'summary/{metric_list[-1]}'] = metric[-1]
print(metric_dict)
wandb.log(metric_dict)


In [31]:
import torch
from einops import rearrange, repeat
from torchvision.utils import make_grid
a = torch.randn([1,3,256,256])
b = torch.randn([1,3,256,256])
sample = torch.cat([a, b], dim=0)
all_samples.append(sample)

a = torch.randn([1,3,256,256])
b = torch.randn([1,3,256,256])
sample = torch.cat([a, b], dim=0)
all_samples.append(sample)

a = torch.randn([1,3,256,256])
b = torch.randn([1,3,256,256])
sample = torch.cat([a, b], dim=0)
all_samples.append(sample)

In [32]:
grid = torch.stack(all_samples, 0)

In [33]:
grid = rearrange(grid, 'n b c h w -> (n b) c h w')

In [34]:
grid = make_grid(grid, nrow=4)

In [36]:
grid.shape

torch.Size([3, 1292, 1034])