In [1]:
from timm.models.layers.helpers import to_2tuple
import timm
import torch.nn as nn
import torch
import torch.backends.cudnn as cudnn

from torchsummary import summary
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn.functional as F

import os
import pandas as pd
import openslide
from PIL import Image
import numpy as np
import random
import time

from tqdm.notebook import trange, tqdm

In [2]:
gpu = 2  # index of the CUDA device to use when a GPU is available

DATA_DIR = 'Data/'  # root folder of the dataset files

# Run settings, kept in a Series so they can be read as attributes
# (e.g. ``args.num_workers``).
args = pd.Series(
    dict(
        batch_size_per_gpu=16,  # 512
        num_workers=8,
        image_dir=os.path.join(DATA_DIR, 'train_images'),
        train_val_split=0.8,
    )
)

In [None]:
# Use the configured GPU when CUDA is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device(f'cuda:{gpu}')
else:
    device = torch.device('cpu')

# Let cuDNN benchmark and pick convolution algorithms for the input sizes it sees.
cudnn.benchmark = True

CTransPath from https://github.com/Xiyue-Wang/TransPath

Make sure to install the version of timm specified in that repository.

In [None]:
class ConvStem(nn.Module):
    """Convolutional patch-embedding stem passed to the Swin model as ``embed_layer``.

    Two stride-2 3x3 Conv/BatchNorm/ReLU stages followed by a 1x1 projection
    shrink each spatial dimension by a factor of 4 and produce ``embed_dim``
    channels.

    Args:
        img_size: Expected input height/width (int or pair).
        patch_size: Must be 4; matches the total stride of the two conv stages.
        in_chans: Number of channels of the input image.
        embed_dim: Number of output channels; must be divisible by 8.
        norm_layer: Optional callable ``norm_layer(embed_dim)`` building the
            normalisation applied to the output; identity when falsy.
        flatten: If true, return tokens shaped ``(B, N, embed_dim)`` instead of
            a ``(B, embed_dim, H/4, W/4)`` feature map.
        output_fmt: Accepted for signature compatibility; not used here.

    Raises:
        AssertionError: If ``patch_size`` is not 4, ``embed_dim`` is not a
            multiple of 8, or (in ``forward``) the input size does not match
            ``img_size``.
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=768, norm_layer=None,
                 flatten=True, output_fmt=None):
        super().__init__()

        assert patch_size == 4
        assert embed_dim % 8 == 0

        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        stem = []
        # Start from the caller-supplied channel count (previously hard-coded to 3,
        # which silently ignored ``in_chans``).
        input_dim, output_dim = in_chans, embed_dim // 8
        for _ in range(2):
            stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
            stem.append(nn.BatchNorm2d(output_dim))
            stem.append(nn.ReLU(inplace=True))
            input_dim = output_dim
            output_dim *= 2
        # 1x1 projection from embed_dim // 4 channels up to the full embedding width.
        stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1))
        self.proj = nn.Sequential(*stem)

        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        """Embed a ``(B, C, H, W)`` batch; H and W must equal ``img_size``."""
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x

def ctranspath():
    """Build the CTransPath backbone: a Swin-Tiny model whose patch embedding is ``ConvStem``.

    The model is created without pretrained weights.
    """
    return timm.create_model(
        'swin_tiny_patch4_window7_224',
        pretrained=False,
        embed_layer=ConvStem,
    )

Look at model breakdown

In [7]:
# Build the (randomly initialised) network and print a per-layer breakdown
# for a single 3x224x224 input.
model = ctranspath()
model = model.cuda()
summary(model, input_size=(3, 224, 224))

torch.Size([2, 3, 224, 224])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 12, 112, 112]             324
       BatchNorm2d-2         [-1, 12, 112, 112]              24
              ReLU-3         [-1, 12, 112, 112]               0
            Conv2d-4           [-1, 24, 56, 56]           2,592
       BatchNorm2d-5           [-1, 24, 56, 56]              48
              ReLU-6           [-1, 24, 56, 56]               0
            Conv2d-7           [-1, 96, 56, 56]           2,400
         LayerNorm-8             [-1, 3136, 96]             192
          ConvStem-9             [-1, 3136, 96]               0
          Dropout-10             [-1, 3136, 96]               0
        LayerNorm-11             [-1, 3136, 96]             192
           Linear-12              [-1, 49, 288]          27,936
          Softmax-13            [-1, 3, 49, 49]               0
          

In [37]:
# Display the module tree of the network (its repr) as the cell output.
model

SwinTransformer(
  (patch_embed): ConvStem(
    (proj): Sequential(
      (0): Conv2d(3, 12, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(12, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (4): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Conv2d(24, 96, kernel_size=(1, 1), stride=(1, 1))
    )
    (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
  )
  (layers): Sequential(
    (0): SwinTransformerStage(
      (downsample): Identity()
      (blocks): Sequential(
        (0): SwinTransformerBlock(
          (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=96, out_features=288, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=Fa

Load pretrained weights

In [8]:
# ============ building network ... ============
model = ctranspath()
# Replace the classification head with a pass-through so the model returns
# features; this is done before the strict state-dict load below.
model.head = nn.Identity()
# load weights to evaluate
# map_location='cpu' makes the load independent of the GPU the checkpoint was
# saved from; the model is moved to the GPU afterwards.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
td = torch.load(r'ctranspath.pth', map_location='cpu')
model.load_state_dict(td['model'], strict=True)

model = model.cuda()
model.eval()

#embed_dim = model.embed_dim # returns 96 for initial conv layer but actually might be 768
embed_dim = model.layers[-1].blocks[-1].mlp.fc2.out_features # this is 768

print("CTransPath model built.")

CTransPath model built.


In [10]:
# Per-layer breakdown of the loaded model for a single 3x224x224 input.
summary(model, input_size=(3, 224, 224))

torch.Size([2, 3, 224, 224])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 12, 112, 112]             324
       BatchNorm2d-2         [-1, 12, 112, 112]              24
              ReLU-3         [-1, 12, 112, 112]               0
            Conv2d-4           [-1, 24, 56, 56]           2,592
       BatchNorm2d-5           [-1, 24, 56, 56]              48
              ReLU-6           [-1, 24, 56, 56]               0
            Conv2d-7           [-1, 96, 56, 56]           2,400
         LayerNorm-8             [-1, 3136, 96]             192
          ConvStem-9             [-1, 3136, 96]               0
          Dropout-10             [-1, 3136, 96]               0
        LayerNorm-11             [-1, 3136, 96]             192
           Linear-12              [-1, 49, 288]          27,936
          Softmax-13            [-1, 3, 49, 49]               0
          

A model head for the output we want will be added during training.

### Check passing data into model

#### Construct data loader

Load data

In [None]:
# Patch table: one row per patch; PatchDataset reads its image_id, x_coord
# and y_coord columns.
patch_df = pd.read_csv(
    'train_patches.csv',
    index_col=0,
)

# Label table; PatchDataset reads its image_id and label_cat columns.
train_df = pd.read_csv(
    os.path.join(DATA_DIR, 'train_hot.csv'),
    index_col=0,
)

Define transforms

In [None]:
class MyRotateTransform:
    """Rotate an image by an angle picked at random from a fixed set of angles."""

    def __init__(self, angles):  #: Sequence[int]):
        # Candidate rotation angles; one is drawn on every call.
        self.angles = angles

    def __call__(self, x):
        """Return ``x`` rotated by a randomly chosen entry of ``self.angles``."""
        return transforms.functional.rotate(x, random.choice(self.angles))
    

# Training pipeline: resize, random flips / colour / blur / sharpness / contrast /
# right-angle rotation, then tensor conversion and normalisation.
train_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.1, contrast=0.25, saturation=0.5, hue=0.25),
    transforms.GaussianBlur(kernel_size=(9, 9)),
    transforms.RandomAdjustSharpness(sharpness_factor=2.0, p=0.2),
    transforms.RandomAutocontrast(p=0.5),
    MyRotateTransform(angles=[-90, 0, 90, 180]),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Validation pipeline: deterministic -- resize, tensor conversion, same normalisation.
val_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

Define dataset and data loader

In [None]:
class PatchDataset(Dataset):
    """Dataset of (patch, label) pairs read on the fly from whole-slide images.

    Args:
        patch_df: DataFrame with one row per patch and the columns
            ``image_id``, ``x_coord`` and ``y_coord``.
        label_df: DataFrame with the columns ``image_id`` and ``label_cat``.
        image_folder: Directory containing ``<image_id>.tif`` slides.
        transform: Optional callable applied to the patch as a PIL image.
        patch_size: Side length of the square region read at level 0.
    """

    def __init__(self, patch_df, label_df, image_folder, transform=None, patch_size=256):
        self.patch_df = patch_df
        self.label_df = label_df
        self.image_folder = image_folder
        self.transform = transform
        self.patch_size = patch_size

    def __len__(self):
        """Number of patches (rows of ``patch_df``)."""
        return len(self.patch_df)

    def __getitem__(self, idx):
        """Return ``(patch, label)`` for the patch at position ``idx``.

        ``patch`` is the transformed image when a transform is set, otherwise
        an RGB numpy array of shape ``(patch_size, patch_size, 3)``.
        """
        # Positional lookup: samplers hand out 0..len-1, which only coincides
        # with the frame's index labels when the index is a fresh range.
        patch_info = self.patch_df.iloc[idx]
        # Plain ints, so numpy scalar types never reach the OpenSlide call.
        x, y = int(patch_info['x_coord']), int(patch_info['y_coord'])
        image_id = patch_info['image_id']

        slide_path = os.path.join(self.image_folder, f"{image_id}.tif")
        # Close the slide once the region is read; otherwise every item leaks
        # an open slide handle.
        with openslide.OpenSlide(slide_path) as wsi:
            region = wsi.read_region((x, y), level=0, size=(self.patch_size, self.patch_size))
            patch = np.array(region)[..., :3]  # keep the first three channels, drop the rest

        if self.transform is not None:
            pil_img = Image.fromarray(patch)
            patch = self.transform(pil_img)

        label = self.label_for_id(image_id)

        return patch, label

    def label_for_id(self, image_id):
        """Return ``label_cat`` of the first ``label_df`` row matching ``image_id``.

        Raises:
            IndexError: If no row of ``label_df`` has this ``image_id``.
        """
        matches = self.label_df.loc[self.label_df['image_id'] == image_id, 'label_cat']
        if matches.empty:
            raise IndexError(f"No label found for image_id {image_id!r}")
        return matches.iloc[0]
        

In [126]:
# Quick end-to-end check on the first 1000 patches only.
# reset_index gives the slice a fresh 0..n-1 index so the positions handed out
# by the loader's sampler always line up with the dataset's row lookup.
patch_subset = PatchDataset(patch_df=patch_df.iloc[:1000].reset_index(drop=True), label_df=train_df,
                            image_folder=args.image_dir, transform=val_transform)

# Sequential (unshuffled) loading so features come out in patch order.
patch_loader = DataLoader(patch_subset, batch_size=args.batch_size_per_gpu,
                          shuffle=False, num_workers=args.num_workers,
                          pin_memory=True)

#### Pass input

In [None]:
slide_features = []
labels = []

model.eval()

# Send inputs to wherever the model's weights actually live. The model was moved
# with .cuda() (the current CUDA device), which is not necessarily the configured
# `device`; a mismatch would make the forward pass fail.
model_device = next(model.parameters()).device

# Inference only: no gradients are needed for any batch.
with torch.no_grad():
    for inp, lbl in patch_loader:
        inp = inp.to(model_device, non_blocking=True)

        labels.extend(lbl)  # one entry per patch, in loader order
        slide_features.append(model(inp))  # one (batch, features) tensor per step

# Stack the per-batch outputs into a single (num_patches, features) tensor.
slide_features = torch.cat(slide_features, dim=0)

In [150]:
# Sanity check: one feature row per patch passed through the model.
slide_features.shape

torch.Size([1000, 768])

In [40]:
# Repeated shape check of the extracted features.
slide_features.shape

torch.Size([1000, 768])