# Paper
This notebook contains the demo code for the paper "Mental image reconstruction from human brain activity: Neural decoding of mental imagery via deep neural network-based Bayesian estimation".

## Setup
Clone the repository, then download the decoded features and the pretrained VQGAN model.

In [None]:
# Clone the repository and move into it; later cells open files by relative
# path (e.g. ref_images/, demo_params.yaml, config.yaml), so they expect the
# working directory to be the repository root.
# TODO: Set the repository URL (XXXXX is a placeholder).
!git clone XXXXX
%cd mental-image-reconstruction

# Display the eyecatch image: the opened PIL image is the cell's last
# expression, so Jupyter renders it as the cell output.
from PIL import Image
eyecatch_path = './ref_images/eyecatch__mental_im_recon__+_x_ver.png'
Image.open(eyecatch_path)

In [None]:
# Download the feature data archive from Google Drive and unpack it into the
# repository root.
import os
import tarfile

import gdown

# Set the public file ID
# TODO: Replace the placeholder with the Google Drive file ID of the archive.
file_id = 'XXXX'
download_url = f'https://drive.google.com/uc?id={file_id}'

# Resolve paths against the current working directory (the repository root,
# entered by `%cd mental-image-reconstruction` in the clone cell) instead of
# hard-coding the Colab-only '/content/...' prefix.
repo_dir = os.getcwd()
download_path = os.path.join(repo_dir, 'downloaded_file.tar.gz')

# Download the tar.gz file. gdown may return None instead of raising when a
# download fails, so check explicitly rather than failing later in tarfile.
downloaded = gdown.download(download_url, download_path, quiet=False)
if downloaded is None:
    raise RuntimeError(f'Failed to download {download_url} to {download_path}')

# Extract the downloaded tar.gz file. When the running Python supports
# extraction filters, use the 'data' filter so archive members cannot be
# written outside repo_dir (absolute paths, '..' components, unsafe links).
extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}
with tarfile.open(download_path, 'r:gz') as tar:
    tar.extractall(path=repo_dir, **extract_kwargs)

In [None]:
# Install VQGAN: clone CompVis taming-transformers and download the pretrained
# `vqgan_imagenet_f16_1024` checkpoint and its model config into
# logs/vqgan_imagenet_f16_1024/{checkpoints,configs}/, the paths the
# model-loading cell reads relative to `taming_transformer_dir`.
# NOTE(review): assumes config.yaml's `taming_transformer_dir` points at this
# clone — verify.
!git clone https://github.com/CompVis/taming-transformers
%cd taming-transformers
!mkdir -p logs/vqgan_imagenet_f16_1024/checkpoints
!mkdir -p logs/vqgan_imagenet_f16_1024/configs
# Pretrained weights -> checkpoints/last.ckpt
!wget 'https://heibox.uni-heidelberg.de/f/140747ba53464f49b476/?dl=1' -O 'logs/vqgan_imagenet_f16_1024/checkpoints/last.ckpt'
# Model config -> configs/model.yaml
!wget 'https://heibox.uni-heidelberg.de/f/6ecf2af6c658432c8298/?dl=1' -O 'logs/vqgan_imagenet_f16_1024/configs/model.yaml'
# Modify part of the code to be compatible with PyTorch 2.x: `torch._six` is
# no longer available there, so the `string_classes` import is replaced with
# an alias to the builtin `str`.
!sed -i 's/from torch._six import string_classes/string_classes = str/' ./taming/data/utils.py
# Editable install of the cloned package into the kernel's environment.
%pip install -e .
# Return to the repository root.
%cd ..

Install the minimal required dependencies.

In [None]:
# Install the minimal Python dependencies, plus OpenAI CLIP from GitHub.
# NOTE(review): versions are unpinned, so a future release of any of these
# packages can change behavior — TODO pin versions (`pkg==x.y.z`) for
# reproducibility.
%pip install mat73 omegaconf einops ftfy regex tqdm pytorch-lightning
%pip install git+https://github.com/openai/CLIP.git

In [None]:
import sys

import torch
import yaml

# Load the demo parameters (passed to the reconstruction as `prm_demo`).
with open('demo_params.yaml', 'rb') as f:
    prm_demo = yaml.safe_load(f)
# Load the image-reconstruction config (file paths and model settings).
with open('config.yaml', 'rb') as f:
    dt_cfg = yaml.safe_load(f)

# Put the taming-transformers checkout at the front of the import path
# (presumably so `model_loading` can import `taming`). Drop any existing entry
# first so re-running this cell does not keep adding duplicates to sys.path.
dir_taming_transformer = dt_cfg['file_path']['taming_transformer_dir']
while dir_taming_transformer in sys.path:
    sys.path.remove(dir_taming_transformer)
sys.path.insert(0, dir_taming_transformer)

# Use the first CUDA GPU if one is available, otherwise fall back to the CPU.
cudaID = "cuda:0"
DEVICE = torch.device(cudaID if torch.cuda.is_available() else "cpu")

## Load the VQGAN, VGG19, and CLIP models

In [None]:
import os

import model_loading

# Load the pretrained `vqgan_imagenet_f16_1024` VQGAN (config + checkpoint
# fetched by the VQGAN install cell) and switch it to eval mode.
vqgan_dir = os.path.join(dir_taming_transformer, "logs", "vqgan_imagenet_f16_1024")
config1024 = model_loading.load_config(
    os.path.join(vqgan_dir, "configs", "model.yaml"), display=False)
VQGANmodel1024 = model_loading.load_vqgan(
    config1024, ckpt_path=os.path.join(vqgan_dir, "checkpoints", "last.ckpt")).to(DEVICE)
VQGANmodel1024.eval()

# Load the VGG19 model; its second return value is not used in this notebook.
VGGmodel_, _ = model_loading.load_VGG_model(DEVICE)

# Load the CLIP models configured under models.CLIP in config.yaml.
clip_cfg = dt_cfg["models"]["CLIP"]
# NOTE(review): CLIP_modelNames, CLIP_usedLayer and nameOfSubdirForCLIPfeature
# are not used by any later cell in this notebook; kept for inspection.
CLIP_modelNames = clip_cfg["modelnames"]
CLIP_modelTypes = clip_cfg["modeltypes"]
CLIP_usedLayer = clip_cfg["used_layer"]
# Coefficients handed to the reconstruction (presumably one per CLIP model).
CLIPmodelWeight_ = clip_cfg["modelcoefs"]
CLIPmodel_, nameOfSubdirForCLIPfeature = model_loading.load_CLIP_model(
    CLIP_modelTypes, DEVICE)

## Set parameters
Set the parameters for image reconstruction.
Select the subject (S01, S02, S03), the target images to reconstruct (IDs 0–24), and the reconstruction method (original, Langevin, withoutLangevin, withoutVQGAN).

In [None]:
subject = 'S02'  # select from 'S01', 'S02', 'S03'

targetID_list = [21, 20, 18, 19, 7, 14]  # select from 0 to 24

# select from 'original' (default), 'Langevin', 'withoutLangevin', 'withoutVQGAN'
reconMethod = 'original'

# Validate the hand-edited choices above so a typo or out-of-range ID fails
# here with a clear message rather than later during reconstruction.
VALID_SUBJECTS = ('S01', 'S02', 'S03')
N_TARGET_IMAGES = 25  # valid target IDs are 0 .. 24
VALID_RECON_METHODS = ('original', 'Langevin', 'withoutLangevin', 'withoutVQGAN')

if subject not in VALID_SUBJECTS:
    raise ValueError(f"subject must be one of {VALID_SUBJECTS}, got {subject!r}")
if not targetID_list:
    raise ValueError("targetID_list must contain at least one target image ID")
invalid_target_ids = [
    target_id for target_id in targetID_list
    if not (isinstance(target_id, int) and 0 <= target_id < N_TARGET_IMAGES)
]
if invalid_target_ids:
    raise ValueError(
        f"target IDs must be integers in 0..{N_TARGET_IMAGES - 1}, "
        f"got invalid: {invalid_target_ids}")
if reconMethod not in VALID_RECON_METHODS:
    raise ValueError(
        f"reconMethod must be one of {VALID_RECON_METHODS}, got {reconMethod!r}")

## Reconstruction

In [None]:
# Run the reconstruction using the subject, target IDs and method chosen in the
# parameter cell, together with the configs and models loaded earlier.
import recon_utils as utils

recon_settings = (subject, targetID_list, reconMethod, dt_cfg, prm_demo)
recon_models = (CLIPmodel_, VGGmodel_, CLIPmodelWeight_, VQGANmodel1024)
utils.start_reconstruction(*recon_settings, *recon_models, DEVICE)