In [None]:
!pip install self-supervised -Uq

In [None]:
import os
from fastai import *
from fastai.data.all import *
from fastai.vision.all import *
from self_supervised.augmentations import *
from self_supervised.layers import *
from self_supervised.vision.swav import *
from torch.utils.data import Dataset
import pandas as pd
import torch
import torchvision.models as models
import torchvision.transforms as transforms

In [None]:
def get_x(x):
    """Return the path of the training image for a label row.

    Args:
        x: Row-like object whose ``'id'`` entry is the image id (a string).

    Returns:
        ``<data_dir>/train/<id>.tif``, where ``data_dir`` is the module-level
        dataset root.
    """
    image_id = x['id']
    return ''.join((data_dir, '/train/', image_id, '.tif'))

def get_dls(size, bs, df):
    """Build image-classification DataLoaders from a labelled DataFrame.

    Args:
        size: Accepted for interface compatibility; not used by this function
            (no resize transform is applied here).
        bs: Batch size forwarded to ``DataBlock.dataloaders``.
        df: DataFrame read through ``get_x`` for the image path and through
            ``ColReader('label')`` for the target. ``ColSplitter()`` is used
            with its default column.
            # presumably the default split column is 'is_valid' — confirm in fastai docs

    Returns:
        The DataLoaders produced by the DataBlock.
    """
    return DataBlock(
        blocks=(ImageBlock(), CategoryBlock()),
        get_x=get_x,
        get_y=ColReader('label'),
        splitter=ColSplitter(),
    ).dataloaders(df, bs=bs)

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

In [None]:
# Root folder of the histopathologic-cancer-detection dataset; training image
# paths (see get_x) are built relative to it.
data_dir = "../input/histopathologic-cancer-detection"

In [None]:
# Load the training labels. The path is derived from data_dir so the dataset
# location is configured in exactly one place.
df = pd.read_csv(os.path.join(data_dir, 'train_labels.csv'))
df.head()

In [None]:
class GaussianNoise:
    """Applies random Gaussian noise to a tensor.

    The intensity of the noise is dependent on the mean of the pixel values:
    for every call an integer signal-to-noise ratio ``snr`` is drawn uniformly
    from ``[snr_low, snr_high)`` and zero-mean noise with standard deviation
    ``|mean(sample)| / snr`` is added.
    See https://arxiv.org/pdf/2101.04909.pdf for more information.

    Args:
        snr_low: Smallest SNR that can be drawn (inclusive). Defaults to 4.
        snr_high: Upper bound of the SNR range (exclusive). Defaults to 8.

    Raises:
        ValueError: If the range is empty or ``snr_low`` is not positive.
    """

    def __init__(self, snr_low: int = 4, snr_high: int = 8) -> None:
        if snr_low < 1 or snr_high <= snr_low:
            raise ValueError("expected 1 <= snr_low < snr_high")
        self.snr_low = snr_low
        self.snr_high = snr_high

    def __call__(self, sample: torch.Tensor) -> torch.Tensor:
        """Return ``sample`` plus freshly drawn noise (same shape, dtype, device)."""
        # Draw the SNR from torch's RNG so torch.manual_seed alone makes the
        # augmentation reproducible (and no NumPy name is needed).
        snr = int(torch.randint(self.snr_low, self.snr_high, (1,)).item())
        # abs(): a standard deviation must be non-negative even when the input
        # has a negative mean (e.g. after normalisation).
        sigma = sample.mean().abs() / snr
        # randn_like keeps the noise on the sample's device and in its dtype.
        noise = torch.randn_like(sample) * sigma
        return sample + noise

In [None]:
# Augmentation pipeline applied to a PIL image; the last two steps work on tensors.
data_transformer = transforms.Compose(
    [
        # Collapse to grayscale but keep three (identical) channels.
        transforms.Grayscale(num_output_channels=3),
        # Random crop covering 20%-100% of the image area, resized to 96x96.
        transforms.RandomResizedCrop(96, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.GaussianBlur(kernel_size=11),
        transforms.ToTensor(),
        # Mean-dependent Gaussian noise, added after conversion to a tensor.
        GaussianNoise(),
    ]
)

In [None]:
example_image_name = 'f38a6374c348f90b587e046aac6079959adf3835.tif'
example_image_path = os.path.join(data_dir, f"train/{example_image_name}")
example_image = Image.open(example_image_path)

# The transform returns a channels-first tensor; only the first channel is shown.
augmented_image_1 = data_transformer(example_image).numpy()[0]
augmented_image_2 = data_transformer(example_image).numpy()[0]

fig, axs = plt.subplots(1, 3)

# One panel per image: the original followed by two independent augmentations.
for ax, (title, image) in zip(
    axs,
    [
        ('Original Image', example_image),
        ('Augmented-1', augmented_image_1),
        ('Augmented-2', augmented_image_2),
    ],
):
    ax.imshow(image)
    ax.set_axis_off()
    ax.set_title(title)

In [None]:
batch_size = 32

dir_of_files = data_dir+"/train"
# os.listdir() returns names in arbitrary order, so sort them: the positional
# 80/20 cut below then selects the same files on every run and filesystem.
filenames = sorted(os.listdir(dir_of_files))
# Keep only the .tif images and drop the extension to recover each image id.
files = [os.path.splitext(f)[0] for f in filenames if f.endswith(".tif")]

cut = int(0.8 * len(files))

train_files = files[:cut]
valid_files = files[cut:]

# Feature extraction uses the first 20% of the train split and the first 10%
# of the validation split.
fe_train_len = int(0.2*len(train_files))
fe_valid_len = int(0.1*len(valid_files))

fe_train_files = train_files[:fe_train_len]
fe_valid_files = valid_files[:fe_valid_len]

print(len(fe_train_files))
print(len(fe_valid_files))

In [None]:
# Flag the rows used for feature extraction (is_fe) and, among them, the rows
# held out for validation (is_valid).
df['is_valid'] = False
df['is_fe'] = False
df.loc[df['id'].isin(fe_valid_files), 'is_valid'] = True
df.loc[df['id'].isin(fe_valid_files), 'is_fe'] = True
df.loc[df['id'].isin(fe_train_files), 'is_fe'] = True

# Only the last expression of a cell is displayed, so the first summary has to
# be printed explicitly or it is silently discarded.
print(df.groupby('is_valid').label.value_counts())
df.groupby('is_fe').label.value_counts()

In [None]:
# Image side length in pixels: passed to get_dls and used as the large crop
# size of the SwAV augmentation pipelines.
size=96

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by BatchNorm.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels of both convolutions.
        stride: Stride of the first convolution (the second always uses 1).
        downsample: Optional module applied to the input to build the shortcut
            when the input cannot be added to the main path as-is.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        conv3x3 = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(inplanes, planes, stride=stride, **conv3x3)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, stride=1, **conv3x3)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        # Main path: conv -> bn -> relu -> conv -> bn.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Shortcut: the raw input unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
        
class ResNet(nn.Module):
    """ResNet-style classifier assembled from a configurable residual block.

    The network is a stride-2 7x7 stem with max pooling, four residual stages
    of 64/128/256/512 base channels, global average pooling and a linear head.

    Args:
        block: Residual block class, called as
            ``block(inplanes, planes, stride, downsample)``. Its optional
            ``expansion`` class attribute (default 1) gives the ratio between
            the block's output channels and ``planes``.
        layers: Four integers, the number of blocks in each stage.
        num_classes: Output size of the final linear layer.
    """

    def __init__(self, block, layers, num_classes=1000):
        super().__init__()

        self.inplanes = 64

        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # The last stage emits 512 * expansion channels, so the head must match.
        self.fc = nn.Linear(512 * self._expansion(block), num_classes)

    @staticmethod
    def _expansion(block):
        """Return ``block.expansion``, or 1 for blocks that do not declare it."""
        return getattr(block, "expansion", 1)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride``. It also receives a 1x1
        conv + BatchNorm projection shortcut whenever the stride or the
        channel count changes across it.
        """
        out_planes = planes * self._expansion(block)

        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride,
                          bias=False),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]

        # Every following block consumes what the previous one produced.
        self.inplanes = out_planes
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Map a ``(N, 3, H, W)`` batch to ``(N, num_classes)`` outputs."""
        # Spatial sizes relative to the input side H (e.g. 96 -> 48 -> 24 ... -> 3).
        x = self.conv1(x)           # H/2  (stride-2 stem convolution)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)         # H/4

        x = self.layer1(x)          # H/4  (stride 1)
        x = self.layer2(x)          # H/8
        x = self.layer3(x)          # H/16
        x = self.layer4(x)          # H/32

        x = self.avgpool(x)         # 1x1
        x = torch.flatten(x, 1)     # drop the 1x1 grid -> (N, channels)
        x = self.fc(x)

        return x

def resnet_temp():
    """Return a ``ResNet`` of ``BasicBlock``s with a single block in each of its four stages."""
    return ResNet(BasicBlock, [1, 1, 1, 1])

def weights_copy(custom_model, resnet18):
    """Graft the stem, the first block of each stage and the head of ``resnet18`` onto ``custom_model``.

    The modules are assigned by reference, not cloned, so afterwards both
    models share the very same layer objects (and therefore parameters).

    Args:
        custom_model: Model to modify in place. It must expose ``conv1``,
            ``bn1``, ``maxpool``, ``layer1``..``layer4`` (indexable),
            ``avgpool`` and ``fc``.
        resnet18: Source model exposing the same attributes.

    Returns:
        ``custom_model``, for convenience.
    """
    # Operate on the argument itself rather than on a module-level model, so
    # the function works for whichever model the caller passes in.
    custom_model.conv1 = resnet18.conv1
    custom_model.bn1 = resnet18.bn1
    custom_model.maxpool = resnet18.maxpool

    # Only the first block of every stage is taken over.
    for stage in ("layer1", "layer2", "layer3", "layer4"):
        getattr(custom_model, stage)[0] = getattr(resnet18, stage)[0]

    custom_model.avgpool = resnet18.avgpool
    custom_model.fc = resnet18.fc

    return custom_model

# Instantiate the reduced ResNet (one block per stage) and print its layer structure.
model_custom = resnet_temp()
print(model_custom)


In [None]:
# NOTE(review): `arch` is not read by any cell visible here, and the encoder
# built below is the reduced custom ResNet, not an xresnet18 — confirm the label.
arch = "xresnet18"
# Load the pretrained torchvision ResNet-18 and graft its layers onto the
# custom model; `encoder` ends up referring to `model_custom`.
encoder = weights_copy(model_custom, models.resnet18(pretrained=True))

In [None]:
# Keep only the rows flagged for feature extraction (is_fe is a boolean column).
df_for_fe = df[df['is_fe']]
len(df_for_fe)

In [None]:
# DataLoaders for self-supervised pretraining, built from the rows flagged is_fe.
dls = get_dls(size, batch_size, df_for_fe)

In [None]:
# Build the SwAV model around the encoder.
model = create_swav_model(encoder)
# The list arguments are parallel, one entry per crop group: 2 crops of side
# `size` and 6 crops of side int(3/4*size), each with its own scale range.
# presumably this is SwAV's multi-crop setup (few large + several small views)
# — confirm against the self_supervised documentation.
# Rotation, colour jitter, black-and-white and blur are all switched off.
aug_pipelines = get_swav_aug_pipelines(num_crops=[2,6],
                                       crop_sizes=[size,int(3/4*size)], 
                                       min_scales=[0.25,0.2],
                                       max_scales=[1.0,0.35],
                                       rotate=False, jitter=False, bw=False, blur=False) 

In [None]:
# K is 16 times the batch size and is handed to the SWAV callback as `K`.
# presumably K is the length of SwAV's feature queue — confirm in self_supervised docs
K = batch_size*2**4
# crop_assgn_ids=[0,1] refers to the first two crops.
# presumably these are the views used to compute cluster assignments — confirm
cbs=[SWAV(aug_pipelines, crop_assgn_ids=[0,1], K=K, queue_start_pct=0.5, temp=0.1)]

In [None]:
# Learner that trains the SwAV model on the pretraining DataLoaders with the
# SWAV callback attached; no loss function or metrics are passed here.
learn = Learner(dls, model, cbs=cbs)

In [None]:
# Preview the augmentations: push one batch through the Learner's
# 'before_batch' event by hand, then let the SWAV callback display 5 samples.
# NOTE(review): relies on the private Learner._split — may break across fastai versions.
b = dls.one_batch()
learn._split(b)
learn('before_batch')
learn.swav.show(n=5);

In [None]:
# Pretraining hyper-parameters: learning rate, weight decay and number of epochs.
lr, wd =1e-4, 1e-2
epochs =10
learn.unfreeze()
# Early stopping monitors the training loss with min_delta=0.1 and patience=2.
# presumably training stops after 2 epochs without an improvement of at least 0.1 — confirm
learn.fit_flat_cos(epochs, lr, wd, pct_start=0.5, cbs=EarlyStoppingCallback(monitor='train_loss', min_delta=0.1, patience=2))

In [None]:
output_path = "./models/"
save_name = f'swav_{size}_epc{epochs}'
# Checkpoint of the whole learner (SwAV model with encoder and head).
learn.save(save_name)
# torch.save does not create missing folders; make sure the target directory
# exists instead of relying on another call having created it.
os.makedirs(output_path, exist_ok=True)
# Encoder weights only, for reuse in the downstream classifier.
torch.save(learn.model.encoder.state_dict(), output_path+save_name+'_encoder.pth')
learn.recorder.plot_loss()

**Evaluating**

In [None]:
import albumentations as A 
from albumentations.pytorch.transforms import ToTensorV2

import torch
import torchvision.transforms as transforms
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import cv2
import gc

In [None]:
# Settings for the evaluation stage. batch_size rebinds the value used during
# pretraining (32).
img_size = 96
batch_size = 24

In [None]:
# Rebinds data_transformer (previously the augmentation pipeline) to a plain
# conversion to tensor, with no augmentation.
data_transformer = transforms.Compose([transforms.ToTensor()])

In [None]:
batch_size = 24

dir_of_files = data_dir+"/train"
# os.listdir() returns names in arbitrary order, so sort them: the positional
# 80/20 cut below then selects the same files on every run and filesystem.
filenames = sorted(os.listdir(dir_of_files))
# Keep only the .tif images and drop the extension to recover each image id.
files = [os.path.splitext(f)[0] for f in filenames if f.endswith(".tif")]

cut = int(0.8 * len(files))

train_files = files[:cut]
valid_files = files[cut:]

# The downstream task uses the first 10% of the train split and the whole
# validation split.
downstream_train_len = int(0.1*len(train_files))
downstream_valid_len = int(1*len(valid_files))

downstream_train_files = train_files[:downstream_train_len]
downstream_valid_files = valid_files[:downstream_valid_len]

print(len(downstream_train_files))
print(len(downstream_valid_files))

In [None]:
# Reset is_valid and flag the rows used by the downstream classifier
# (is_downstream) together with its validation rows (is_valid).
df['is_valid'] = False
df['is_downstream'] = False
df.loc[df['id'].isin(downstream_valid_files), 'is_valid'] = True
df.loc[df['id'].isin(downstream_valid_files), 'is_downstream'] = True
df.loc[df['id'].isin(downstream_train_files), 'is_downstream'] = True

# Only the last expression of a cell is displayed, so the first summary has to
# be printed explicitly or it is silently discarded.
print(df.groupby('is_valid').label.value_counts())
df.groupby('is_downstream').label.value_counts()

In [None]:
# Keep only the rows flagged for the downstream task (is_downstream is a boolean column).
df_for_downstream = df[df['is_downstream']]
len(df_for_downstream)

In [None]:
# Rebinds dls to the DataLoaders of the downstream classification task.
dls = get_dls(size, batch_size, df_for_downstream)

In [None]:
# Keyword arguments baked into the optimizer factory used for fine-tuning.
optdict = dict(sqr_mom=0.99,mom=0.95,beta=0.,eps=1e-4)
# opt_func is `ranger` with the settings above pre-applied.
opt_func = partial(ranger, **optdict)

In [None]:
def split_func(m):
    """Return two parameter groups for ``m``: those of ``m[0]`` and those of ``m[1]``."""
    submodules = L(m[0], m[1])
    return submodules.map(params)

def create_learner(size=96, arch='resnet50', encoder_path="./models/swav_96_epc10_encoder.pth"):
    """Build a fine-tuning Learner from the pretrained encoder weights.

    The module-level ``encoder`` is loaded with the weights stored at
    ``encoder_path`` and a classification head sized for ``dls.c`` classes is
    stacked on top of it. ``dls``, ``opt_func`` and ``split_func`` are read
    from module scope.

    Args:
        size: Side length of the dummy input used to measure the encoder's
            output width.
        arch: Currently unused: the encoder is the module-level ``encoder``,
            not one created from an architecture name.
        encoder_path: File holding the encoder ``state_dict``.

    Returns:
        A ``Learner`` over ``nn.Sequential(encoder, classifier)`` with
        accuracy as its metric and a flattened cross-entropy loss.
    """
    # Load onto the CPU: load_state_dict copies the values into the encoder's
    # existing parameters, so this works whatever device the file was saved from.
    pretrained_encoder = torch.load(encoder_path, map_location='cpu')
    encoder.load_state_dict(pretrained_encoder)

    # Measure the encoder's output width with a dummy batch. Run it in eval
    # mode and without gradients so the random input cannot overwrite the
    # freshly loaded BatchNorm statistics, and put it on the encoder's device.
    encoder_device = next(encoder.parameters()).device
    was_training = encoder.training
    encoder.eval()
    with torch.no_grad():
        nf = encoder(torch.randn(2, 3, size, size, device=encoder_device)).size(-1)
    encoder.train(was_training)

    classifier = create_cls_module(nf, dls.c)
    model = nn.Sequential(encoder, classifier)
    learn = Learner(dls, model, opt_func=opt_func, splitter=split_func,
                metrics=[accuracy], loss_func=CrossEntropyLossFlat())
    return learn

In [None]:
def finetune(size, epochs, arch, encoder_path, lr=1e-2, wd=1e-2):
    """Fine-tune a classifier on top of the pretrained encoder and report its accuracy.

    Args:
        size, arch, encoder_path: Forwarded to ``create_learner``.
        epochs: Number of epochs for ``fit_flat_cos``.
        lr: Learning rate.
        wd: Weight decay.

    Returns:
        The ``accuracy`` value recorded for the last epoch.

    Raises:
        ValueError: If the learner does not record a metric named ``accuracy``.
    """
    learn = create_learner(size, arch, encoder_path)
    learn.unfreeze()
    learn.fit_flat_cos(epochs, lr, wd=wd)
    # Locate the column by name rather than by a fixed offset: with a single
    # metric, the second-to-last recorded value is the validation loss.
    # Rows of recorder.values omit the leading 'epoch' entry of metric_names,
    # hence the -1.
    acc_col = list(learn.recorder.metric_names).index('accuracy') - 1
    final_acc = learn.recorder.values[-1][acc_col]
    return final_acc

In [None]:
# Repeat the fine-tuning `runs` times and collect the value returned by each run.
runs = 1
acc = [
    finetune(96, epochs=50, arch='resnet50', encoder_path='./models/swav_96_epc10_encoder.pth')
    for _ in range(runs)
]