# Dependencies

Upstream code (GitHub): https://github.com/mahmoodlab/HIPT

In [None]:
# Make Google Drive visible under /content/gdrive so the HIPT_4K checkout,
# images and label CSV stored there can be read by the cells below.
from google.colab import drive

drive.mount(mountpoint='/content/gdrive')


In [None]:
from __future__ import print_function

import glob
from itertools import chain
import os
import random
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

In [None]:
%cd /content/gdrive/MyDrive/Ben/HIPT_4K
!pip install -r requirements.txt

# Standalone HIPT_4K Model Inference and Fine-Tuning

In [None]:
# Training settings
batch_size = 1  # one image per optimizer step
epochs = 10
lr = 3e-5
# StepLR multiplies the learning rate by `gamma` on every scheduler step, so the
# previous value of 0. would have set the LR to zero after the first step.
# 0.7 is a gentle per-epoch decay; tune as needed.
gamma = 0.7

seed = 42
device = 'cpu'

In [None]:
# Folder holding the PNG images used for training and validation.
# (An alternative image set lives under '/content/gdrive/MyDrive/Ben/image58'.)
HIPT_ROOT = '/content/gdrive/MyDrive/Ben/HIPT_4K'
train_dir = HIPT_ROOT + '/image'

In [None]:
# Collect the PNG images to train on. Sorted so that the seeded train/valid
# split further down is reproducible regardless of filesystem listing order.
train_list = sorted(glob.glob(os.path.join(train_dir, '*.png')))
# Fail early with a clear message instead of deep inside the preview/split cells.
assert train_list, f"no .png images found in {train_dir}"


In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA
    devices), and configures cuDNN to pick deterministic kernels.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects Python interpreters launched after this point; the hash seed
    # of the current process was fixed when it started.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every CUDA device (a no-op on CPU-only machines).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning selects kernels by timing, which can vary between runs.
    torch.backends.cudnn.benchmark = False

seed_everything(seed=seed)  # apply the configured seed before any random draws

In [None]:
# Label table: the cells below look up each image's `label` by matching its
# file stem against the `image_id` column. (`pd` comes from the imports cell.)
train = pd.read_csv('/content/gdrive/MyDrive/Ben/HIPT_4K/train.csv')
missing_cols = {'image_id', 'label'} - set(train.columns)
assert not missing_cols, f"train.csv is missing columns: {missing_cols}"
train

In [None]:
# Binary target per image in `train_list`: 1 when its `label` is "CE", else 0.
# The image id is the file name up to its first '.', as HIPTDataset also uses.
# Build the id -> label mapping once (first row wins on duplicate ids, like the
# original per-image `.values[0]`) instead of scanning the table per image.
label_by_id = train.drop_duplicates('image_id').set_index('image_id')['label']
image_ids = [path.split('/')[-1].split('.')[0] for path in train_list]
labels = np.array([1 if label_by_id.loc[image_id] == "CE" else 0 for image_id in image_ids])
labels

In [None]:
# Preview up to nine randomly chosen training images, titled with their binary
# label (1 = CE). The original drew random indices but never used them, always
# showing the first nine images.
preview_rng = np.random.default_rng(seed)
n_preview = min(9, len(train_list))
preview_idx = preview_rng.choice(len(train_list), size=n_preview, replace=False)

fig, axes = plt.subplots(3, 3, figsize=(16, 12))
for ax in axes.ravel()[n_preview:]:
    ax.axis('off')  # blank panels when there are fewer than nine images
for ax, idx in zip(axes.ravel(), preview_idx):
    # Convert to an array inside `with` so the file is closed promptly.
    with Image.open(train_list[idx]) as img:
        ax.imshow(np.asarray(img))
    ax.set_title(labels[idx])

In [None]:
# Stratified 80/20 split: `labels` is aligned with the full `train_list`, so both
# subsets keep the CE / non-CE ratio.
# NOTE: this rebinds `train_list`. Re-running the cell without first re-running
# the globbing cell fails, because `labels` no longer matches its length.
train_list, valid_list = train_test_split(
    train_list,
    stratify=labels,
    test_size=0.2,
    random_state=seed,
)

In [None]:
# Report how many images ended up in each split.
for split_name, split_files in (("Train", train_list), ("Validation", valid_list)):
    print(f"{split_name} Data: {len(split_files)}")

In [None]:
resize = 224  # side length for the resize/crop augmentations, which are currently disabled

# Both splits currently only convert PIL images to tensors; the resize, crop and
# flip augmentations were switched off.
# NOTE(review): the inference demo preprocesses with `eval_transforms()` from
# hipt_model_utils; confirm whether training inputs need the same normalization.
train_transforms = transforms.Compose([transforms.ToTensor()])
val_transforms = transforms.Compose([transforms.ToTensor()])

In [None]:
class HIPTDataset(Dataset):
    """Image dataset yielding ``(image, label)`` pairs for CE-vs-other classification.

    ``label`` is 1 when the row whose ``image_id`` equals the file's stem (the
    name up to its first '.') has ``label == "CE"``, and 0 otherwise.

    Args:
        file_list: Paths of the image files.
        transform: Callable applied to each opened PIL image. When ``None``
            (the default), the PIL image itself is returned; previously this
            default crashed by calling ``None``.
        label_df: Table with ``image_id`` and ``label`` columns. When ``None``,
            the notebook-level ``train`` table is used, looked up at access time.
    """

    def __init__(self, file_list, transform=None, label_df=None):
        self.file_list = file_list
        self.transform = transform
        self.label_df = label_df

    def __len__(self):
        self.filelength = len(self.file_list)
        return self.filelength

    @staticmethod
    def _image_id(img_path):
        """Return the file name up to its first '.', the key used in ``image_id``."""
        return img_path.split("/")[-1].split(".")[0]

    def _label_for(self, img_path):
        """Return 1 if the first row matching this image's id is labelled "CE", else 0.

        Raises IndexError when no row matches the image id.
        """
        table = self.label_df if self.label_df is not None else train
        raw_label = table.loc[table['image_id'] == self._image_id(img_path), 'label'].values[0]
        return 1 if raw_label == "CE" else 0

    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        # Load eagerly inside `with` so the file handle is closed deterministically
        # (important with several DataLoader workers).
        with Image.open(img_path) as img:
            img.load()
            # A closed PIL image is unusable, so return a copy when untransformed.
            sample = self.transform(img) if self.transform is not None else img.copy()
        return sample, self._label_for(img_path)

In [None]:
# Wrap each split's file list; labels are resolved per item from the `train` table.
train_data, valid_data = (
    HIPTDataset(train_list, transform=train_transforms),
    HIPTDataset(valid_list, transform=val_transforms),
)

In [None]:
# Batch the datasets. Only training needs shuffling. A fixed validation order keeps
# the per-batch-averaged validation metrics identical across epochs and avoids
# drawing from the global RNG.
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=4)
valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size, shuffle=False, num_workers=4)

In [None]:
print(*(len(obj) for obj in (train_data, train_loader)))  # images, batches

In [None]:
print(*map(len, (valid_data, valid_loader)))  # images, batches

In [None]:
from hipt_4k import HIPT_4K
from hipt_model_utils import get_vit256, get_vit4k, eval_transforms
from hipt_heatmap_utils import *
# `matplotlib` is referenced below but imported nowhere else in this notebook;
# importing it explicitly avoids relying on the star import to provide it.
import matplotlib

# Colormap derived from jet via `cmap_map` (from hipt_heatmap_utils); presumably
# maps each channel x -> x/2 + 0.5, i.e. a lighter jet, for attention heatmaps.
light_jet = cmap_map(lambda x: x/2 + 0.5, matplotlib.cm.jet)

# Paths are relative to the HIPT_4K checkout (the working directory set earlier).
pretrained_weights256 = '../Checkpoints/vit256_small_dino.pth'
pretrained_weights4k = '../Checkpoints/vit4k_xs_dino.pth'
device256 = torch.device("cpu")
device4k = torch.device("cpu")

### ViT_256 + ViT_4K loaded independently (used for Attention Heatmaps)
# NOTE(review): the `device=` arguments were commented out by the author; confirm
# which device these helpers default to in the local hipt_model_utils.
model256 = get_vit256(pretrained_weights=pretrained_weights256) #, device=device256)
model4k = get_vit4k(pretrained_weights=pretrained_weights4k)# , device=device4k)

### ViT_256 + ViT_4K loaded into HIPT_4K API
model = HIPT_4K(pretrained_weights256, pretrained_weights4k, device256, device4k)
model.eval()

In [None]:
# Smoke test: run the pretrained HIPT_4K model on one demo region and show the
# input/output tensor shapes.
with Image.open('./image_demo/image_4k.png') as region:
    x = eval_transforms()(region).unsqueeze(dim=0)  # add a batch dimension
# Inference only: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    out = model(x)  # call the module rather than .forward() so hooks run
print('Input Shape:', x.shape)
print('Output Shape:', out.shape)


In [None]:
# Objective, optimiser and learning-rate schedule for fine-tuning.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=lr)
# StepLR multiplies the LR by `gamma` after every `step_size` calls to scheduler.step().
scheduler = StepLR(optimizer=optimizer, step_size=1, gamma=gamma)

In [None]:
# TensorBoard logging. SummaryWriter() writes event files under ./runs/ by
# default, which is the directory the last cell points TensorBoard at.
# (`torch` itself is already imported in the imports cell.)
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
# Re-pinned to CPU, matching device256/device4k used to build `model`.
device = 'cpu'

In [None]:

# Fine-tune `model` for `epochs` epochs. After each epoch, evaluate on the validation
# loader and log loss/accuracy to TensorBoard and stdout.
# NOTE(review): CrossEntropyLoss with 0/1 targets expects one logit per class.
# Confirm that HIPT_4K's output (see the shape printed by the demo cell) is a
# 2-class head rather than a feature embedding.
# NOTE(review): `scheduler` is never stepped here, so the LR stays at `lr`.
for epoch in range(epochs):
    # The model was put in eval mode when it was built; switch to training mode
    # so layers such as dropout (if present) behave as in training.
    model.train()
    epoch_loss = 0.0
    epoch_accuracy = 0.0

    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)
        output = model(data)
        loss = criterion(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = (output.argmax(dim=1) == label).float().mean()
        # Accumulate Python floats: summing the loss tensor itself keeps each
        # batch's autograd history alive for the whole epoch.
        epoch_accuracy += acc.item() / len(train_loader)
        epoch_loss += loss.item() / len(train_loader)

    # Inference behaviour and no autograd graph while validating.
    model.eval()
    with torch.no_grad():
        epoch_val_accuracy = 0.0
        epoch_val_loss = 0.0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc.item() / len(valid_loader)
            epoch_val_loss += val_loss.item() / len(valid_loader)

    writer.add_scalar("Loss/train", epoch_loss, epoch+1)
    writer.add_scalar("Acc/train", epoch_accuracy, epoch+1)
    writer.add_scalar("Loss/Validation", epoch_val_loss, epoch+1)
    writer.add_scalar("Acc/Validation", epoch_val_accuracy, epoch+1)
    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )


In [None]:
# Persist the fine-tuned weights, then close the TensorBoard writer; close()
# flushes pending events, so no separate flush() call is needed.
state = model.state_dict()
torch.save(state, "HIPT_v2.pt")
writer.close()

In [None]:
!tensorboard --logdir=runs 