<a href="https://colab.research.google.com/github/prachimodi-142/CancerDetection/blob/master/Active_learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch.utils.data import Dataset,DataLoader
from torchvision import utils, models,datasets
import torchvision.transforms as transforms
import os
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch.nn as nn
import time
import torch.nn.functional as F
from groupy.gconv.pytorch_gconv import P4MConvZ2, P4MConvP4M

import random
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision

In [None]:
def new_getitem(self, index: int):
    """Fetch one entry of the dataset and report which entry it was.

    Args:
        index (int): Position of the entry inside ``self.samples``.

    Returns:
        tuple: ``(sample, target, index)`` -- the loaded (and optionally
        transformed) sample, its (optionally transformed) class index, and
        the ``index`` that was requested, passed back unchanged.
    """
    file_path, label = self.samples[index]
    image = self.loader(file_path)

    # Sample transform first, then the target transform; both are optional.
    image_tf = self.transform
    if image_tf is not None:
        image = image_tf(image)

    label_tf = self.target_transform
    if label_tf is not None:
        label = label_tf(label)

    return image, label, index


In [None]:
# Monkey-patch ImageFolder so that every item comes back as
# (sample, target, index) instead of (sample, target). The assignment is made
# on the class itself, so it applies to every ImageFolder instance; the
# query/train/test loops in this notebook unpack three values per batch.
datasets.ImageFolder.__getitem__ = new_getitem

In [None]:
class block(nn.Module):
    """Bottleneck residual block built from ``P4MConvP4M`` layers.

    The main path is 1x1 conv -> 3x3 conv (carries ``stride``) -> 1x1 conv,
    each followed by ``BatchNorm3d``; the last convolution widens the
    channel count to ``intermediate_planes * expansion``. The block input is
    added back onto the main path before the final ReLU.

    Args:
        in_planes: Number of input channels.
        intermediate_planes: Channel count of the two inner convolutions.
        identity_downsample: Optional module applied to the block input so
            that it matches the main path's output before the addition.
        stride: Stride of the 3x3 convolution.
    """

    # Channel multiplier applied by the final 1x1 convolution.
    expansion = 4

    def __init__(self, in_planes, intermediate_planes, identity_downsample=None, stride=1):
        super().__init__()
        out_planes = intermediate_planes * self.expansion
        self.conv1 = P4MConvP4M(in_planes, intermediate_planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(intermediate_planes)
        self.conv2 = P4MConvP4M(
            intermediate_planes, intermediate_planes,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm3d(intermediate_planes)
        self.conv3 = P4MConvP4M(
            intermediate_planes, out_planes,
            kernel_size=1, stride=1, padding=0, bias=False,
        )
        self.bn3 = nn.BatchNorm3d(out_planes)
        self.relu = nn.ReLU()
        self.identity_downsample = identity_downsample
        self.stride = stride

    def forward(self, x):
        """Run the bottleneck path and add the (optionally projected) input."""
        # No clone needed: every layer below returns a new tensor and the
        # in-place add at the end writes into the bn3 output, not the input.
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)

        out += identity
        return self.relu(out)


class ResNet(nn.Module):
    """ResNet-style classifier assembled from P4M convolution blocks.

    Args:
        block: Residual block class; called as
            ``block(in_planes, intermediate_planes, identity_downsample, stride)``
            and expected to output ``intermediate_planes * 4`` channels.
        layers: Four integers, the number of residual blocks in each stage.
        image_channels: Number of channels of the input images.
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, block, layers, image_channels, num_classes):
        super(ResNet, self).__init__()
        self.in_planes = 23
        # The stem now honours image_channels (it used to hard-code 3).
        self.conv1 = P4MConvZ2(image_channels, 23, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(23)
        self.relu = nn.ReLU()
        # Kernel/stride of 1 on the first pooled axis: only the last two axes are pooled.
        self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))

        # The four residual stages; each call also advances self.in_planes.
        self.layer1 = self._make_layer(block, layers[0], intermediate_plane=23, stride=1)
        self.layer2 = self._make_layer(block, layers[1], intermediate_plane=45, stride=2)
        self.layer3 = self._make_layer(block, layers[2], intermediate_plane=91, stride=2)
        self.layer4 = self._make_layer(block, layers[3], intermediate_plane=181, stride=2)

        # NOTE(review): the tensor reaching this layer appears to be 5-D (the
        # network uses BatchNorm3d/MaxPool3d throughout); presumably the (1, 1)
        # target reduces only the last two axes -- confirm against the
        # installed PyTorch version.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # self.in_planes is 181 * 4 here, so this equals the former 181*8*4.
        # The factor 8 is presumably the size of the extra axis that survives
        # the pooling above -- TODO confirm.
        self.fc = nn.Linear(self.in_planes * 8, num_classes)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Flatten everything except the batch axis for the linear head.
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)

        return x

    def _make_layer(self, block, num_residual_blocks, intermediate_plane, stride):
        """Build one stage of ``num_residual_blocks`` blocks.

        Only the first block receives ``stride`` and, when needed, a
        projection for the skip connection. ``self.in_planes`` is updated to
        the stage's output channel count as a side effect.
        """
        out_planes = intermediate_plane * 4  # the block expands its channels by 4
        identity_downsample = None

        # When the spatial size is reduced (stride != 1) or the channel count
        # changes, the skip connection must be projected so that it can be
        # added to the block output.
        if stride != 1 or self.in_planes != out_planes:
            identity_downsample = nn.Sequential(
                P4MConvP4M(self.in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(out_planes),
            )

        layers = [block(self.in_planes, intermediate_plane, identity_downsample, stride)]
        self.in_planes = out_planes

        # The remaining blocks keep both shape and channel count (stride 1,
        # in_planes == out_planes), so they need no projection.
        for _ in range(num_residual_blocks - 1):
            layers.append(block(self.in_planes, intermediate_plane))

        return nn.Sequential(*layers)

In [None]:
def least_confidence_query(model, device, data_loader, query_size):
    """Select the pool samples the model is least confident about.

    Args:
        model: Classifier returning one row of logits per input sample.
        device: Device the input batches are moved to before the forward pass.
        data_loader: Iterable yielding ``(data, target, index)`` batches;
            the targets are ignored.
        query_size: Maximum number of dataset indices to return.

    Returns:
        numpy.ndarray: Up to ``query_size`` dataset indices, ordered from the
        lowest to the highest top-class probability.
    """
    model.eval()
    top_probs = []
    pool_indices = []

    progress = tqdm(data_loader)
    with torch.no_grad():
        for inputs, _, batch_idx in progress:
            # Softmax over the class axis so each row sums to one.
            class_probs = F.softmax(model(inputs.to(device)), dim=1)

            # The confidence of a sample is the probability of its top class.
            best_prob, _ = torch.max(class_probs, dim=1)
            top_probs.extend(best_prob.cpu().tolist())
            pool_indices.extend(batch_idx.tolist())

    # Ascending sort puts the least confident samples first.
    ranking = np.argsort(np.asarray(top_probs))
    ranked_indices = np.asarray(pool_indices)[ranking]
    return ranked_indices[:query_size]

In [None]:
def query_the_oracle(model, device, dataset, tracker, query_size,batch_size):
    """Mark the least-confident unlabelled samples as labelled.

    Args:
        model: Classifier used to score the unlabelled pool.
        device: Device on which the scoring runs.
        dataset: Dataset whose items are ``(sample, target, index)`` tuples.
        tracker: Array with one entry per dataset item; non-zero means the
            item is still unlabelled, zero means it has been labelled.
        query_size: Maximum number of items to label in this round.
        batch_size: Batch size used while scoring the pool.

    Returns:
        The same ``tracker`` object, updated in place.
    """
    # Every item with a non-zero tracker entry is still in the pool.
    pool_idx = np.nonzero(tracker)[0]

    pool_sampler = SubsetRandomSampler(pool_idx)
    pool_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=pool_sampler,
        num_workers=24,
    )

    # Dataset indices of the samples with the lowest top-class confidence.
    chosen = least_confidence_query(model, device, pool_loader, query_size)

    # Flip the chosen entries to "labelled", one index at a time.
    for picked in chosen:
        tracker[picked] = 0
    return tracker

In [None]:
def train(model, device, train_loader, optimizer, criterion):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network to optimise; it is switched to training mode.
        device: Device the inputs and targets are moved to.
        train_loader: Iterable yielding ``(data, target, index)`` batches;
            the index element is ignored.
        optimizer: Optimiser stepped once per batch.
        criterion: Loss called as ``criterion(output, target)``.

    Returns:
        The sum of the per-batch loss values (``0`` if the loader is empty).
    """
    model.train()
    running_loss = 0

    for inputs, labels, _ in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Clear gradients from the previous step before the new backward pass.
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss

In [None]:
def test(model, device, criterion,test_loader):
    """Evaluate ``model`` on ``test_loader`` and report its accuracy.

    Args:
        model: Classifier to evaluate; it is switched to eval mode.
        device: Device the inputs and targets are moved to.
        criterion: Loss called as ``criterion(output, target)``; its value is
            only shown on the progress bar.
        test_loader: Loader yielding ``(data, target, index)`` batches. The
            accuracy is divided by ``len(test_loader.dataset)``, so the
            loader is expected to cover the whole dataset.

    Returns:
        float: Percentage of samples whose arg-max prediction equals the
        target.
    """
    model.eval()
    progress = tqdm(test_loader)
    correct = 0
    with torch.no_grad():
        for inputs, labels, _ in progress:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)

            loss = criterion(logits, labels)
            progress.set_postfix(loss=loss.item())

            # Index of the largest logit is the predicted class.
            predicted = torch.max(logits, 1)[1]
            # .item() keeps the tally a plain Python int. Accumulating the
            # tensor made the function return a 0-dim (possibly CUDA) tensor,
            # whose str() leaked into file names and whose integer division
            # floors on old PyTorch releases.
            correct += (predicted == labels).sum().item()

    accuracy = correct * 100 / len(test_loader.dataset)
    print(f'Test accuracy: {accuracy:.3f}%')
    return accuracy

In [None]:
# Per-channel normalisation constants shared by both pipelines.
# presumably the RGB mean/std of the training images -- TODO confirm
_norm_mean = [0.6716241, 0.48636872, 0.60884315]
_norm_std = [0.27210504, 0.31001145, 0.2918652]

# Training pipeline: resize, random colour/geometric augmentation, then
# tensor conversion and normalisation.
train_transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.25, hue=0.1),
    transforms.RandomAffine(degrees=10, translate=(0.05, 0.05)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])

# Evaluation pipeline: same resize and normalisation, no augmentation.
test_transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
_backend = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_backend)
print(device)

In [None]:
# Training images, read from class-named sub-folders and augmented on access.
train_data = datasets.ImageFolder(
    root="PCam/Pcam_Train/Pcam_Train",
    transform=train_transform,
)
print(len(train_data))


In [None]:
# Held-out images, loaded with the augmentation-free pipeline.
test_data = datasets.ImageFolder(
    root="PCam/Pcam_Test_192/Pcam_Test_192",
    transform=test_transform,
)
print(len(test_data))

In [None]:
# Seed first so the weights initialised below are reproducible.
torch.manual_seed(42)
criterion = nn.CrossEntropyLoss()

# Both loaders share the same batching, shuffling and worker settings.
_loader_kwargs = dict(batch_size=48, shuffle=True, num_workers=24)
train_loader = DataLoader(train_data, **_loader_kwargs)
test_loader = DataLoader(test_data, **_loader_kwargs)

# Stage depths [3, 4, 6, 3], 3 image channels, 2 output classes.
model = ResNet(block, [3, 4, 6, 3], 3, 2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

# One flag per training image, all starting at 1; the loop below treats
# entries equal to 0 as labelled.
tracker = np.ones(len(train_data))

In [None]:
# Active-learning loop: each round moves up to 6000 of the least-confident
# pool images into the labelled set, then trains on everything labelled so far.
best_acc = 0.0
for i in range(10):
    print(f"Query {i}")
    tracker = query_the_oracle(model, device, train_data, tracker, query_size=6000, batch_size=48)

    # Tracker entries equal to 0 are the images labelled so far.
    labeled_idx = np.where(tracker == 0)[0]
    print(len(labeled_idx))
    labeled_loader = DataLoader(
        train_data,
        batch_size=48,
        sampler=SubsetRandomSampler(labeled_idx),
        num_workers=24,
    )

    # Later rounds (i > 4) get 5 epochs instead of 3.
    epochs = 5 if i > 4 else 3
    for k in range(epochs):
        train_loss = train(model, device, labeled_loader, optimizer, criterion)

        # float() accepts a Python number as well as a 0-dim tensor, so the
        # checkpoint name never contains a tensor repr such as
        # "tensor(85.1, device='cuda:0')".
        current_test_acc = float(test(model, device, criterion, test_loader))

        # Checkpoint only when the test accuracy improves on the best so far.
        if current_test_acc > best_acc:
            torch.save(model.state_dict(), f"resnet50_active_{current_test_acc:.3f}_{i}.pt")
            best_acc = current_test_acc