In [1]:
#Conversion of ViVit from keras to pytorch
# Original keras code -> https://keras.io/examples/vision/vivit/

!pip install -qq medmnist
import medmnist

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import torchvision.transforms.functional as TF
from torch.utils.data import DataLoader, Dataset
import numpy as np
import keras





[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/88.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━[0m [32m61.4/88.4 kB[0m [31m1.6 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.4/88.4 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for fire (setup.py) ... [?25l[?25hdone


In [2]:

SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)

DATASET_NAME = "organmnist3d"
BATCH_SIZE = 32

INPUT_SHAPE = (28, 28, 28, 1)
NUM_CLASSES = 11

# Hyperparameters for Adam Optimizer
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5

# TRAINING
EPOCHS = 60

# Tubelet Embedding shape
PATCH_SIZE = (8, 8, 8)
NUM_PATCHES = (INPUT_SHAPE[0] // PATCH_SIZE[0]) ** 2

# ViViT ARCHITECTURE
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 128
NUM_HEADS = 8
NUM_LAYERS = 8

def download_and_prepare_dataset(data_info: dict):
    """Utility function to download the dataset.

    Arguments:
        data_info (dict): Dataset metadata.
    """
    data_path = keras.utils.get_file(origin=data_info["url"], md5_hash=data_info["MD5"])

    with np.load(data_path) as data:
        # videos
        train_videos = data["train_images"]
        valid_videos = data["val_images"]
        test_videos = data["test_images"]

        # labels
        train_labels = data["train_labels"].flatten()
        valid_labels = data["val_labels"].flatten()
        test_labels = data["test_labels"].flatten()

    return (
        (train_videos, train_labels),
        (valid_videos, valid_labels),
        (test_videos, test_labels),
    )


info = medmnist.INFO[DATASET_NAME]

prepared_dataset = download_and_prepare_dataset(info)
(train_videos, train_labels) = prepared_dataset[0]
(valid_videos, valid_labels) = prepared_dataset[1]
(test_videos, test_labels) = prepared_dataset[2]

def preprocess(frames, label):
    """Preprocess the frames tensors and parse the labels for PyTorch."""
    #frames is 4D tensor, of the form -> ( batch_size, height, width, depth)
    # depth is number of frames in each video
    frames = torch.tensor(frames)
    # Normalizing
    if frames.dtype != torch.float32:
        frames = frames.to(torch.float32) / 255.0

    frames = frames.unsqueeze(0)
    label=torch.tensor(label)
    label = label.to(torch.float32)

    return frames, label

class CustomDataset(Dataset):
    def __init__(self, videos, labels, transform=None):
        self.videos = videos
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        video = self.videos[idx]
        label = self.labels[idx]

        if self.transform:
            video, label = self.transform(video, label)

        return video, label

def prepare_dataloader(videos, labels, batch_size, loader_type="train"):
      dataset = CustomDataset(videos, labels, transform=preprocess)
      shuffle = True if loader_type == "train" else False

      dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4, pin_memory=True)
      return dataloader


trainloader = prepare_dataloader(train_videos, train_labels,BATCH_SIZE, "train")
validloader = prepare_dataloader(valid_videos, valid_labels,BATCH_SIZE, "valid")
testloader = prepare_dataloader(test_videos, test_labels,BATCH_SIZE, "test")

print(len(trainloader.dataset))
print(len(trainloader))

print(f"train labels size {len(train_labels)}")

for i, data in enumerate(trainloader, 0):
    videos, labels = data
    print(f'Video batch dimensions: {videos.size()}')
    break




Downloading data from https://zenodo.org/records/10519652/files/organmnist3d.npz?download=1
971
31
train labels size 971


  self.pid = os.fork()
  self.pid = os.fork()


Video batch dimensions: torch.Size([32, 1, 28, 28, 28])


In [15]:

class TubeletEmbedding(nn.Module):
    def __init__(self, embed_dim, patch_size):
        super(TubeletEmbedding, self).__init__()
        self.projection = nn.Conv3d(
            in_channels=1,
            out_channels=embed_dim, # each token of video will have dimensionality 128
            kernel_size=patch_size,
            stride=patch_size,
            padding=0  #zero padding
        )
        self.embed_dim=embed_dim

    def forward(self, videos):
        videos = videos.to(self.projection.weight.device)
        projected_patches = self.projection(videos)
        #print(videos.size()) -- its 32,1,28,28,28 - input video size
        #print(projected_patches.size()) -- after 3D CNN, we get output matrix of shape, [32,128,3,3,3]
        flattened_patches =projected_patches.view(projected_patches.size(0), -1, self.embed_dim)

        # print(flattened_patches.size()) -- we flattend the 3*3*3 matrix into 1d layer, 27, hence final size is 32,128,27
        return flattened_patches

class PositionalEncoder(nn.Module):
    def __init__(self, embed_dim):
        super(PositionalEncoder, self).__init__()
        self.embed_dim = embed_dim

    def forward(self, encoded_tokens):
        # Encoded_tokens is of shape [batch_size, num_tokens, embed_dim]
        #encoded tokens has shape 32,27,128
        _, num_tokens, _ = encoded_tokens.shape


        position_ids = torch.arange(0, num_tokens, dtype=torch.float32, device=encoded_tokens.device).unsqueeze(1)
        position_embeddings = torch.zeros(num_tokens, self.embed_dim, device=encoded_tokens.device)
        # (num_tokens, embed_dim) - positional embeddings shape


        #sinuisoidal positional embeddings
        position_embeddings[:, 0::2] = torch.sin(position_ids * (10000 ** (-torch.arange(0, self.embed_dim, 2, device=encoded_tokens.device) / self.embed_dim)))
        position_embeddings[:, 1::2] = torch.cos(position_ids * (10000 ** (-torch.arange(1, self.embed_dim, 2, device=encoded_tokens.device) / self.embed_dim)))

        # Adding positional encodings to each token
        encoded_tokens = encoded_tokens + position_embeddings.unsqueeze(0)
        return encoded_tokens

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(TransformerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(embed_dim,eps=1e-6)
        self.attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
        self.norm2 = nn.LayerNorm(embed_dim,eps=1e-6)

        # Gelu activation is used in the MLP
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.GELU()
        )

    def forward(self, x):
        # Transformer block forward function
        x1 = self.norm1(x)
        attn_output, _ = self.attn(x1, x1, x1)
        x2 = attn_output + x
        x3 = self.norm2(x2)
        x3 = self.mlp(x3)
        return x3 + x2

class ViViTClassifier(nn.Module):
    def __init__(self, tubelet_embedder, positional_encoder, embed_dim=PROJECTION_DIM, num_heads=NUM_HEADS, num_classes=NUM_CLASSES, transformer_layers=NUM_LAYERS):
        super(ViViTClassifier, self).__init__()
        self.tubelet_embedder = tubelet_embedder
        self.positional_encoder = positional_encoder
        self.transformer_blocks = nn.Sequential(*[TransformerBlock(embed_dim, num_heads) for _ in range(transformer_layers)])
        self.norm = nn.LayerNorm(embed_dim,eps=1e-6)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.tubelet_embedder(x)
        x = self.positional_encoder(x)
        x = self.transformer_blocks(x)
        x = self.norm(x)
        x = x.transpose(1, 2)  # Transpose for pooling layer, so output becomes, 32, 128,27
        x = self.pool(x).squeeze(-1) # after pooling , output becomes, 32,128
        x = self.classifier(x) # we do FFN of 128*11, as there are final 11 output classes
        return F.softmax(x, dim=-1)

def train(model, trainloader, validloader, optimizer, epochs,device):
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            labels = labels.long()
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {running_loss/len(trainloader)}")

        # Validation step
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in validloader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                labels=labels.long()
                loss = F.cross_entropy(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        print(f"Validation Loss: {val_loss/len(validloader)}, Accuracy: {100 * correct / total}%")

def evaluate(model, testloader,device):
    model.eval()
    correct = 0
    top5_correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test accuracy: {accuracy}%")

def run_experiment():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ViViTClassifier(tubelet_embedder=TubeletEmbedding(embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE),
                            positional_encoder=PositionalEncoder(embed_dim=PROJECTION_DIM),)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    train(model, trainloader, validloader, optimizer, EPOCHS,device)
    evaluate(model, testloader,device)
    return model

model = run_experiment()


Epoch 1, Loss: 2.3905058445469027
Validation Loss: 2.3836456537246704, Accuracy: 9.937888198757763%
Epoch 2, Loss: 2.380719115657191
Validation Loss: 2.3821576038996377, Accuracy: 9.937888198757763%
Epoch 3, Loss: 2.3801726525829685
Validation Loss: 2.377044121424357, Accuracy: 9.937888198757763%
Epoch 4, Loss: 2.3787009023850962
Validation Loss: 2.376653552055359, Accuracy: 9.937888198757763%
Epoch 5, Loss: 2.3757751603280344
Validation Loss: 2.385567585627238, Accuracy: 9.937888198757763%
Epoch 6, Loss: 2.373362910362982
Validation Loss: 2.3742785851160684, Accuracy: 10.559006211180124%
Epoch 7, Loss: 2.3680344012475785
Validation Loss: 2.379710872968038, Accuracy: 11.180124223602485%
Epoch 8, Loss: 2.3306092216122534
Validation Loss: 2.3132792313893638, Accuracy: 26.70807453416149%
Epoch 9, Loss: 2.245137014696675
Validation Loss: 2.200348218282064, Accuracy: 32.91925465838509%
Epoch 10, Loss: 2.2107921723396546
Validation Loss: 2.178258935610453, Accuracy: 35.40372670807454%
Epoch 

In [16]:

import imageio
import io
import ipywidgets
from IPython.display import display

NUM_SAMPLES_VIZ = 25
testsamples, labels = next(iter(testloader))
testsamples, labels = testsamples[:NUM_SAMPLES_VIZ], labels[:NUM_SAMPLES_VIZ]

ground_truths = []
preds = []
videos = []

model.eval()
with torch.no_grad():
    for i, (testsample, label) in enumerate(zip(testsamples, labels)):
        # Generate gif
        testsample_np = testsample.squeeze().numpy()  # testsample is 4D (batch, channel, H, W)
        with io.BytesIO() as gif:
            imageio.mimsave(gif, (testsample_np * 255).astype("uint8"), "GIF", fps=5)
            videos.append(gif.getvalue())

        testsample = testsample.unsqueeze(0)
        output = model(testsample)
        pred = output.max(1)[1].item()
        ground_truths.append(label.item())
        preds.append(pred)


def make_box_for_grid(image_widget, fit):
    """Make a VBox to hold caption/image for demonstrating option_fit values.

    Source: https://ipywidgets.readthedocs.io/en/latest/examples/Widget%20Styling.html
    """

    if fit is not None:
        fit_str = "'{}'".format(fit)
    else:
        fit_str = str(fit)

    h = ipywidgets.HTML(value="" + str(fit_str) + "")
    boxb = ipywidgets.widgets.Box()
    boxb.children = [image_widget]

    vb = ipywidgets.widgets.VBox()
    vb.layout.align_items = "center"
    vb.children = [h, boxb]
    return vb

boxes = []
for i in range(NUM_SAMPLES_VIZ):
    ib = ipywidgets.Image(value=videos[i], width=100, height=100)
    true_class_index = int(ground_truths[i])
    pred_class_index = int(preds[i])
    true_class = info["label"][str(true_class_index)]
    pred_class = info["label"][str(pred_class_index)]
    caption = f"T: {true_class} | P: {pred_class}"

    boxes.append(make_box_for_grid(ib, caption))

grid = ipywidgets.GridBox(boxes, layout=ipywidgets.Layout(grid_template_columns="repeat(5, 200px)"))
display(grid)


GridBox(children=(VBox(children=(HTML(value="'T: pancreas | P: kidney-right'"), Box(children=(Image(value=b'GI…

In [21]:
# Final code to load the saved model /train the model from scratch
torch.save(model.state_dict(), 'ViViTClassifier.pth')
print("Saved trained model state to 'ViViTClassifier.pth'")

# To train the model, type -> python pytorch_vivit.py
# to load an existing model and do eval ->  python  pytorch_vivit.py load

def load_model(model_path, device):
    model = ViViTClassifier(TubeletEmbedding(embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE),
                            positional_encoder=PositionalEncoder(embed_dim=PROJECTION_DIM))
    model.to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    print(f"Loaded model from {model_path}")
    return model

import sys

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if len(sys.argv) > 1 and sys.argv[1] == 'load':
        model = load_model('ViViTClassifier.pth', device)
        evaluate(model, testloader, device)
    else:
        model = run_experiment()

if __name__ == "__main__":
    main()


Saved trained model state to 'ViViTClassifier.pth'
