In [1]:
# List the available conda environments; the active one is marked with `*`.
!conda info -e

# conda environments:
#
base                     /home2/akshett.jindal/miniconda3
subba                 *  /home2/akshett.jindal/miniconda3/envs/subba



In [23]:
!pip install -q transformers torch pillow torchvision tqdm

In [2]:
# Hugging Face checkpoint id used for both the ViT backbone and its image processor.
VIT_PRETRAINED_MODEL = "google/vit-base-patch16-224"
# Root directory handed to the dataset loaders; also the parent of the Hugging Face cache.
BASE_DIRECTORY = "/tmp/akshett.jindal"

In [3]:
import os
# Download location for Hugging Face files; passed as `cache_dir` when the model and processor are loaded.
HF_CACHE_DIR = os.path.join(BASE_DIRECTORY, ".huggingface_cache")

In [4]:
import torch

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

# Hyperparameters

In [53]:
# Number of samples per DataLoader batch.
BATCH_SIZE = 32
# Number of passes over the training set made by the training loop.
NUM_EPOCHS = 10

# Dataset

Load the dataset and wrap it in `DataLoader`s that batch the samples and convert them to tensors.

In [5]:
import nsd_dataset.mind_eye_nsd_utils as menutils

# Stimulus images, later indexed by the per-trial stimulus orderings below.
image_dataset = menutils.load_image_dataset(BASE_DIRECTORY)

# Full session data plus the train/validation split of stimulus orderings
# and voxel responses.
# NOTE(review): the second argument (1) is presumably a subject id - confirm
# against `get_split_data`.
session_data, (
    trn_stim_ordering,
    trn_voxel_data,
    val_stim_ordering,
    val_voxel_data,
) = menutils.get_split_data(BASE_DIRECTORY, 1)

Total number of voxels = 15724


In [7]:
"""
image_dataset => (73000, 227, 227, 3)
session_data => (30000, 15724)
trn_stim_ordering => (27000,)
trn_voxel_data => (27000, 15724)
val_stim_ordering => (3000,)
val_voxel_data => (3000, 15724)
"""

# Report the shape of every loaded array, one per line, so it can be
# compared against the reference values recorded in the string above.
print(
    f"image_dataset => {image_dataset.shape}",
    f"session_data => {session_data.shape}",
    f"trn_stim_ordering => {trn_stim_ordering.shape}",
    f"trn_voxel_data => {trn_voxel_data.shape}",
    f"val_stim_ordering => {val_stim_ordering.shape}",
    f"val_voxel_data => {val_voxel_data.shape}",
    sep="\n",
)

image_dataset => (73000, 227, 227, 3)
session_data => (30000, 15724)
trn_stim_ordering => (27000,)
trn_voxel_data => (27000, 15724)
val_stim_ordering => (3000,)
val_voxel_data => (3000, 15724)


In [8]:
import torch
from torch.utils.data import Dataset
import numpy as np

class FMRIDataset(Dataset):
    """Pairs each fMRI voxel vector with the stimulus image shown on that trial.

    The image store is indexed lazily, one sample at a time, instead of being
    copied up front with ``images[image_order]``. That eager copy duplicates
    one image per trial (repeated stimuli included), which is several
    gigabytes for the training split used in this notebook.

    Args:
        image_order: 1-D sequence of integer indices; entry ``i`` selects the
            row of ``images`` that belongs to trial ``i``.
        images: Array-like image store supporting integer and integer-array
            indexing along its first axis.
        voxel_data: Per-trial voxel responses, one row per trial.

    Raises:
        ValueError: If ``image_order`` and ``voxel_data`` have different
            lengths, since samples could then no longer be paired up.
    """

    def __init__(self, image_order, images, voxel_data):
        if len(image_order) != len(voxel_data):
            raise ValueError(
                "image_order and voxel_data must have the same length, got "
                f"{len(image_order)} and {len(voxel_data)}"
            )

        self._image_store = images
        # An ndarray accepts int, list and slice indices alike.
        self._image_order = np.asarray(image_order)
        self.voxel_data = voxel_data

    @property
    def images(self):
        """Images in trial order.

        Kept for backward compatibility with the former eager attribute; the
        ordered copy is materialised on every access.
        """
        return self._image_store[self._image_order]

    def __len__(self):
        """Number of trials (rows of ``voxel_data``)."""
        return len(self.voxel_data)

    def __getitem__(self, idx):
        """Return ``{'image': float tensor, 'fmri': float tensor}`` for ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Translate the trial index into a row of the image store.
        image = self._image_store[self._image_order[idx]]

        return {
            'image': torch.tensor(image, dtype=torch.float),
            'fmri': torch.tensor(self.voxel_data[idx], dtype=torch.float),
        }

In [13]:
# Both splits draw from the same image store; each brings its own stimulus
# ordering and voxel responses.
train_dataset = FMRIDataset(trn_stim_ordering, image_dataset, trn_voxel_data)
test_dataset = FMRIDataset(val_stim_ordering, image_dataset, val_voxel_data)

# Sanity check: number of samples in each split.
len(train_dataset), len(test_dataset)

(27000, 3000)

In [18]:
from torch.utils.data import DataLoader

# Settings shared by the training and evaluation loaders.
dataloader_kwargs = dict(batch_size=BATCH_SIZE)

# Only the training loader shuffles; the evaluation loader keeps dataset order.
train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, **dataloader_kwargs)
test_dataloader = DataLoader(dataset=test_dataset, **dataloader_kwargs)

train_dataloader, test_dataloader

(<torch.utils.data.dataloader.DataLoader at 0x14fa34e28070>,
 <torch.utils.data.dataloader.DataLoader at 0x14fa34e2ab00>)

# Model

Creating a class for our model and loading the model

In [19]:
from transformers import ViTConfig, ViTModel
import torch.nn as nn

class ViTFMRI(nn.Module):
    """ViT backbone followed by a single linear read-out layer.

    The backbone is loaded with ``ViTModel.from_pretrained`` and its
    ``pooler_output`` is projected to ``final_out_dim`` values.
    """

    def __init__(self, vit_pretrained: str, final_out_dim: int, *args, **kwargs):
        """Load the backbone and build the read-out head.

        Args:
            vit_pretrained: Checkpoint name or path given to
                ``ViTModel.from_pretrained``.
            final_out_dim: Number of output features of the linear head.
            *args, **kwargs: Forwarded unchanged to
                ``ViTModel.from_pretrained`` (e.g. ``cache_dir``).
        """
        super().__init__()

        # Remember how this module was configured.
        self._vit_pretrained, self._out_dim = vit_pretrained, final_out_dim

        backbone = ViTModel.from_pretrained(vit_pretrained, *args, **kwargs)
        self.vit = backbone
        # Head width follows the backbone's hidden size.
        self.linear = nn.Linear(backbone.config.hidden_size, final_out_dim)

    def forward(self, *args, **kwargs):
        """Run the backbone on the given inputs and project its pooled output.

        All arguments are passed straight through to the ViT backbone.
        """
        pooled = self.vit(*args, **kwargs)['pooler_output']
        return self.linear(pooled)

In [20]:
# One output unit per voxel (second dimension of `session_data`).
# `cache_dir` is not a ViTFMRI parameter; it is forwarded to
# `ViTModel.from_pretrained` through **kwargs.
model = ViTFMRI(
    vit_pretrained=VIT_PRETRAINED_MODEL,
    final_out_dim=session_data.shape[1],
    cache_dir=HF_CACHE_DIR,
)
model

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.weight', 'vit.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTFMRI(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3

In [21]:
from transformers import ViTImageProcessor

# Image processor loaded from the same checkpoint as the backbone, so the
# preprocessing settings come from that checkpoint's configuration.
processor = ViTImageProcessor.from_pretrained(
    VIT_PRETRAINED_MODEL,
    cache_dir=HF_CACHE_DIR,
)
processor

ViTImageProcessor {
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}

# Training

In [157]:
import torch

def rankdata_ordinal(a):
    """Return the 1-based ordinal ranks of the flattened input.

    Ties are broken by position because the underlying sort is stable, so
    every element receives a distinct rank.

    Args:
        a: Tensor of any shape; it is flattened before ranking.

    Returns:
        1-D ``torch.int`` tensor with one rank per element, allocated on the
        same device as ``a``.
    """
    arr = torch.ravel(a)
    sorter = torch.argsort(arr, stable=True)

    # Allocate on the input's device: indexing a CPU tensor with CUDA indices
    # (or the reverse) raises a device-mismatch error.
    inv = torch.empty(sorter.shape, dtype=torch.int, device=arr.device)
    inv[sorter] = torch.arange(sorter.numel(), dtype=torch.int, device=arr.device)

    return inv + 1


def rdc_tensor(x, y, f=torch.sin, k=20, s=1/6., n=1):
    """Compute the Randomized Dependence Coefficient with torch tensors.

    Args:
        x, y: 1-D tensors of shape (samples,) or 2-D tensors of shape
            (samples, variables), on the same device.
        f: Non-linearity applied to the random projections.
        k: Number of random projections.
        s: Scale of the random projection weights.
        n: Number of independent evaluations; for ``n > 1`` the median of the
            successful evaluations is returned (``torch.median`` picks the
            lower middle value when their count is even).

    Returns:
        0-dimensional tensor holding the coefficient.

    Raises:
        RuntimeError: If no valid number of projections can be found, or if
            every one of the ``n`` evaluations fails.

    Note:
        The inputs enter the computation only through the integer ranks
        produced by ``rankdata_ordinal``, so the result carries no gradient
        with respect to ``x`` or ``y``.
    """
    if n > 1:
        values = []
        last_error = None
        for _ in range(n):
            try:
                # Recurse into this torch implementation, not the NumPy `rdc`.
                values.append(rdc_tensor(x, y, f, k, s, 1))
            except RuntimeError as ex:  # includes torch.linalg failures
                last_error = ex
        if not values:
            raise RuntimeError(
                "rdc_tensor: all %d evaluations failed" % n
            ) from last_error
        return torch.median(torch.stack(values))

    if x.ndim == 1: x = x.reshape((-1, 1))
    if y.ndim == 1: y = y.reshape((-1, 1))

    # Copula Transformation
    cx = torch.column_stack([rankdata_ordinal(xc) for xc in x.T])/float(x.shape[0])
    cy = torch.column_stack([rankdata_ordinal(yc) for yc in y.T])/float(y.shape[0])

    # Add a vector of ones so that w.x + b is just a dot product
    O = torch.ones(cx.shape[0], dtype=cx.dtype, device=cx.device)
    X = torch.column_stack([cx, O])
    Y = torch.column_stack([cy, O])

    # Random linear projections (created on the same device/dtype as the data)
    Rx = (s/X.shape[1])*torch.randn(X.shape[1], k, dtype=X.dtype, device=X.device)
    Ry = (s/Y.shape[1])*torch.randn(Y.shape[1], k, dtype=Y.dtype, device=Y.device)
    X = torch.mm(X, Rx)
    Y = torch.mm(Y, Ry)

    # Apply non-linear function to random projections
    fX = f(X)
    fY = f(Y)

    # Compute full covariance matrix
    C = torch.cov(torch.hstack([fX, fY]).T)

    # Due to numerical issues, if k is too large,
    # then rank(fX) < k or rank(fY) < k, so we need
    # to find the largest k such that the eigenvalues
    # (canonical correlations) are real-valued
    k0 = k
    lb = 1
    ub = k
    while True:

        if k < 1:
            # The search ran out of candidates; fail with a clear message
            # instead of an opaque error from reducing an empty tensor.
            raise RuntimeError(
                "rdc_tensor: no number of projections yields valid canonical correlations"
            )

        # Compute canonical correlations
        Cxx = C[:k, :k]
        Cyy = C[k0:k0+k, k0:k0+k]
        Cxy = C[:k, k0:k0+k]
        Cyx = C[k0:k0+k, :k]

        eigs = torch.linalg.eigvals(torch.mm(torch.mm(torch.linalg.pinv(Cxx), Cxy),
                                              torch.mm(torch.linalg.pinv(Cyy), Cyx)))

        # eigvals always returns a complex tensor; once the imaginary parts
        # are known to be zero, the real parts are the values that must lie
        # in [0, 1] (a check on |eigs| >= 0 would be vacuous).
        # Binary search if k is too large
        if not (torch.all(torch.isreal(eigs)) and
                0 <= torch.min(eigs.real) and
                torch.max(eigs.real) <= 1):
            ub -= 1
            k = (ub + lb) // 2
            continue
        if lb == ub: break
        lb = k
        if ub == lb + 1:
            k = ub
        else:
            k = (ub + lb) // 2

    return torch.sqrt(torch.max(eigs.real))

In [307]:
# Smoke test: one sequence is the exact reverse of the other, so the two are
# fully dependent.
a1 = torch.tensor([1, 2, 3, 4, 5])
a2 = torch.tensor([5, 4, 3, 2, 1])

print(rdc_tensor(a1, a2))

# Remove the scratch tensors from the notebook namespace.
del a1
del a2

tensor(1.0000)


In [154]:
import numpy as np
from scipy.stats import rankdata

def rdc(x, y, f=np.sin, k=20, s=1/6., n=1):
    """
    Computes the Randomized Dependence Coefficient
    x,y: numpy arrays 1-D or 2-D
         If 1-D, size (samples,)
         If 2-D, size (samples, variables)
    f:   function to use for random projection
    k:   number of random projections to use
    s:   scale parameter
    n:   number of times to compute the RDC and
         return the median (for stability); attempts that
         raise LinAlgError are skipped

    According to the paper, the coefficient should be relatively insensitive to
    the settings of the f, k, and s parameters.

    Raises np.linalg.LinAlgError (a ValueError subclass) when n == 1 and no
    number of projections yields valid canonical correlations.
    """
    if n > 1:
        values = []
        for i in range(n):
            try:
                values.append(rdc(x, y, f, k, s, 1))
            # Public exception name; `np.linalg.linalg` is a private module path.
            except np.linalg.LinAlgError: pass
        return np.median(values)

    if len(x.shape) == 1: x = x.reshape((-1, 1))
    if len(y.shape) == 1: y = y.reshape((-1, 1))

    # Copula Transformation
    # Ranks are divided by the number of samples (rows), matching rdc_tensor;
    # dividing by `.size` would also count the variables of 2-D inputs.
    cx = np.column_stack([rankdata(xc, method='ordinal') for xc in x.T])/float(x.shape[0])
    cy = np.column_stack([rankdata(yc, method='ordinal') for yc in y.T])/float(y.shape[0])

    # Add a vector of ones so that w.x + b is just a dot product
    O = np.ones(cx.shape[0])
    X = np.column_stack([cx, O])
    Y = np.column_stack([cy, O])

    # Random linear projections
    Rx = (s/X.shape[1])*np.random.randn(X.shape[1], k)
    Ry = (s/Y.shape[1])*np.random.randn(Y.shape[1], k)
    X = np.dot(X, Rx)
    Y = np.dot(Y, Ry)

    # Apply non-linear function to random projections
    fX = f(X)
    fY = f(Y)

    # Compute full covariance matrix
    C = np.cov(np.hstack([fX, fY]).T)

    # Due to numerical issues, if k is too large,
    # then rank(fX) < k or rank(fY) < k, so we need
    # to find the largest k such that the eigenvalues
    # (canonical correlations) are real-valued
    k0 = k
    lb = 1
    ub = k
    while True:

        if k < 1:
            # The search ran out of candidates; report it explicitly rather
            # than failing on a reduction over an empty array.
            raise np.linalg.LinAlgError(
                "rdc: no number of projections yields valid canonical correlations"
            )

        # Compute canonical correlations
        Cxx = C[:k, :k]
        Cyy = C[k0:k0+k, k0:k0+k]
        Cxy = C[:k, k0:k0+k]
        Cyx = C[k0:k0+k, :k]

        eigs = np.linalg.eigvals(np.dot(np.dot(np.linalg.pinv(Cxx), Cxy),
                                        np.dot(np.linalg.pinv(Cyy), Cyx)))
        # Binary search if k is too large
        if not (np.all(np.isreal(eigs)) and
                0 <= np.min(eigs) and
                np.max(eigs) <= 1):
            ub -= 1
            k = (ub + lb) // 2
            continue
        if lb == ub: break
        lb = k
        if ub == lb + 1:
            k = ub
        else:
            k = (ub + lb) // 2

    return np.sqrt(np.max(eigs))

In [331]:
# Smoke test of the NumPy implementation on a tiny 1-D example.
a1 = np.array([1, 2, 3, 4, 5])
a2 = np.array([3124, 234, 2345, 54, 543])

print(rdc(a1, a2))

# Remove the scratch arrays from the notebook namespace.
del a1
del a2

0.9998408164605007


In [55]:
import torch.optim

# The training loop minimises whatever `loss_function` returns.
# NOTE(review): rdc_tensor is a dependence score that is larger for more
# strongly related inputs, so minimising it looks inverted - confirm intent.
loss_function = rdc_tensor

# Adam over every parameter of the model (backbone and linear head).
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Move the model to the selected device before training.
model = model.to(device)

In [57]:
from tqdm.auto import tqdm

# Per-epoch bookkeeping. Entries are appended only once an epoch has finished,
# so EPOCHS and TRAIN_LOSSES always have the same length.
EPOCHS = []
TRAIN_LOSSES = []
# TODO: TEST_LOSSES is never filled; no evaluation pass exists in this loop yet.
TEST_LOSSES = []

for epoch_num in tqdm(range(1, NUM_EPOCHS+1), desc="Epochs", position=0):

    model.train(True)

    running_loss = 0

    for data_batch in tqdm(train_dataloader, desc="Train Batch", position=1, leave=False):

        input_images = data_batch['image']
        fmris = data_batch['fmri']

        optimizer.zero_grad()

        # Preprocess the raw image batch and move every resulting tensor to
        # the device the model lives on.
        inputs = processor(input_images, return_tensors="pt")
        for key, value in inputs.items():
            inputs[key] = value.to(device)

        # Predictions are brought back to the CPU, where the fMRI targets
        # produced by the DataLoader live.
        outputs = model(**inputs).cpu()

        # NOTE(review): rdc_tensor derives its value from integer ranks of its
        # inputs, so gradients may not reach the model - confirm that
        # loss.backward() succeeds and actually updates the weights.
        loss = loss_function(outputs, fmris)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()

    # Record the mean batch loss of the finished epoch; the accumulated
    # value would otherwise be discarded.
    EPOCHS.append(epoch_num)
    TRAIN_LOSSES.append(running_loss / max(len(train_dataloader), 1))

Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Train Batch:   0%|          | 0/844 [00:00<?, ?it/s]

torch.Size([32, 227, 227, 3])


In [None]:
# Scratch cell: preprocess a single image.
# NOTE(review): `image` is not defined in any visible cell - this relies on
# leftover kernel state; define it explicitly or remove the cell.
inputs = processor(image, return_tensors="pt")
inputs

In [None]:
# Move every processed tensor to the selected device, replacing the entries in place.
for k in inputs.keys():
    inputs[k] = inputs[k].to(device)

In [None]:
# Sanity check: confirm the pixel tensor now sits on the selected device.
inputs['pixel_values'].device

In [None]:
# Scratch cell: single forward pass with the processed inputs.
outputs = model(**inputs)

In [None]:
# Inspect the raw result of the forward pass.
outputs

In [None]:
# NOTE(review): `ViTFMRI.forward` in this notebook returns the linear layer's
# tensor, which has no 'pooler_output' / 'last_hidden_state' keys. This cell
# appears to predate the wrapper and to expect a bare ViTModel output -
# confirm before running.
print("pooler_output:", outputs['pooler_output'].shape)
print("last_hidden_state:", outputs['last_hidden_state'].shape)

"""
pooler_output: torch.Size([1, 768])
last_hidden_state: torch.Size([1, 197, 768])
"""