In [1]:
from __future__ import print_function, division
import cv2
import os
import torch
import json
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import transforms as T
import utils
from torchvision.transforms import functional as F
from PIL import Image
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

# Maps integer class ids to short class names.
# NOTE(review): presumably normal / benign / malignant -- confirm against the label files.
LABEL_ENUM = {0: "nrml", 1: "benign", 2: "malg"} 

In [2]:
class BusiDataset(Dataset):
    """Image classification dataset with optional Gaussian blurring.

    Every item is a ``(image, label, filename)`` triple read from ``img_dir``.

    Args:
        img_dir: Directory that contains the image files.
        df: Mapping from image filename to a metadata dict.
        labels: Iterable of ``"<filename>,<class id>"`` strings.
        to_blur: When true, the image is Gaussian-blurred after loading.
        blur_kernel_size: Kernel size handed to ``cv2.GaussianBlur``.
        sigma: Sigma handed to ``cv2.GaussianBlur``.
        img_transforms: Optional callable applied to the loaded image.
    """
    def __init__(self, img_dir, df, labels, to_blur=True, blur_kernel_size=(1,1), sigma=0, img_transforms=None):
        self.img_dir = img_dir
        self.transforms = img_transforms
        records = []
        for label in labels:
            key, cls = label.split(",")
            # Shallow-copy the entry so the caller's metadata is left untouched.
            record = dict(df[key])
            record["filename"] = key
            record["label"] = int(cls)
            records.append(record)
        # Despite the name this is a plain list of per-image dicts.
        self.df = records
        self.sigma = sigma
        self.to_blur = to_blur
        self.blur_kernel_size = blur_kernel_size

    def __len__(self):
        """Return the number of labelled images."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(image, label, filename)`` where ``label`` is an int64 tensor.
            Without transforms the image is returned as loaded by OpenCV.

        Raises:
            FileNotFoundError: If OpenCV could not read the image file.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        record = self.df[idx]
        filename = record["filename"]
        img_name = os.path.join(self.img_dir, filename)
        image = cv2.imread(img_name)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError("Could not read image: %s" % img_name)
        if self.to_blur:
            image = cv2.GaussianBlur(image, self.blur_kernel_size, self.sigma)
        # Fall back to the (possibly blurred) image when no transforms are set;
        # previously this path hit an unbound local variable.
        img = self.transforms(image) if self.transforms else image
        label = torch.as_tensor(record["label"], dtype=torch.int64)
        return img, label, filename


class GbRawDataset(Dataset):
    """Image classification dataset that serves whole (uncropped) images.

    Every item is a ``(image, label, filename)`` triple read from ``img_dir``.

    Args:
        img_dir: Directory that contains the image files.
        df: Mapping from image filename to a metadata dict.
        labels: Iterable of ``"<filename>,<class id>"`` strings.
        img_transforms: Optional callable applied to the loaded image.
    """
    def __init__(self, img_dir, df, labels, img_transforms=None):
        self.img_dir = img_dir
        self.transforms = img_transforms
        records = []
        for label in labels:
            key, cls = label.split(",")
            # Shallow-copy the entry so the caller's metadata is left untouched.
            record = dict(df[key])
            record["filename"] = key
            record["label"] = int(cls)
            records.append(record)
        # Despite the name this is a plain list of per-image dicts.
        self.df = records

    def __len__(self):
        """Return the number of labelled images."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(image, label, filename)`` where ``label`` is an int64 tensor.
            Without transforms the image is returned as loaded by OpenCV.

        Raises:
            FileNotFoundError: If OpenCV could not read the image file.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        record = self.df[idx]
        filename = record["filename"]
        img_name = os.path.join(self.img_dir, filename)
        image = cv2.imread(img_name)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError("Could not read image: %s" % img_name)
        # Fall back to the raw image when no transforms are set; previously
        # this path hit an unbound local variable.
        img = self.transforms(image) if self.transforms else image
        label = torch.as_tensor(record["label"], dtype=torch.int64)
        return img, label, filename


def crop_image(image, box, p):
    """Slice a padded region out of ``image``.

    ``box`` is unpacked as ``(x1, y1, x2, y2)``. The start coordinates are
    scaled by ``1 - p`` and the end coordinates by ``1 + p`` before being
    truncated to integers, so the padding grows with the coordinate value.
    Rows are indexed by y and columns by x.
    """
    left, top, right, bottom = box
    shrink = 1 - p
    grow = 1 + p
    row_start, row_stop = int(shrink * top), int(grow * bottom)
    col_start, col_stop = int(shrink * left), int(grow * right)
    return image[row_start:row_stop, col_start:col_stop]


class GbDataset(Dataset):
    """Classification dataset that serves one crop per image.

    The image is optionally Gaussian-blurred and then cropped to the box
    stored under the ``"Gold"`` key of its metadata entry (padded by ``p``,
    see ``crop_image``).

    Args:
        img_dir: Directory that contains the image files.
        df: Mapping from image filename to a metadata dict with a ``"Gold"``
            box given as ``(x1, y1, x2, y2)``.
        labels: Iterable of ``"<filename>,<class id>"`` strings.
        is_train: Stored on the instance; not used by this class itself.
        to_blur: When true, the image is Gaussian-blurred before cropping.
        blur_kernel_size: Kernel size handed to ``cv2.GaussianBlur``.
        sigma: Sigma handed to ``cv2.GaussianBlur``.
        p: Padding fraction handed to ``crop_image``.
        img_transforms: Optional callable applied to the crop.
    """
    def __init__(self, img_dir, df, labels, is_train=True, to_blur=True, blur_kernel_size=(65,65), sigma=0, p=0.15, img_transforms=None):
        self.img_dir = img_dir
        self.transforms = img_transforms
        self.to_blur = to_blur
        self.blur_kernel_size = blur_kernel_size
        self.sigma = sigma
        self.is_train = is_train
        records = []
        for label in labels:
            key, cls = label.split(",")
            # Shallow-copy the entry so the caller's metadata is left untouched.
            record = dict(df[key])
            record["filename"] = key
            record["label"] = int(cls)
            records.append(record)
        # Despite the name this is a plain list of per-image dicts.
        self.df = records
        self.p = p

    def __len__(self):
        """Return the number of labelled images."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(image, label, filename)`` where ``image`` is the (optionally
            transformed) crop and ``label`` is an int64 tensor.

        Raises:
            FileNotFoundError: If OpenCV could not read the image file.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        record = self.df[idx]
        filename = record["filename"]
        img_name = os.path.join(self.img_dir, filename)
        image = cv2.imread(img_name)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError("Could not read image: %s" % img_name)
        if self.to_blur:
            image = cv2.GaussianBlur(image, self.blur_kernel_size, self.sigma)
        image = crop_image(image, record["Gold"], self.p)
        if self.transforms:
            image = self.transforms(image)
        label = torch.as_tensor(record["label"], dtype=torch.int64)
        return image, label, filename


class GbCropDataset(Dataset):
    """Classification dataset that serves every region crop of an image.

    For an image whose metadata lists boxes under ``"Boxes"``, the sample is
    the stack of all box crops. When that list is empty, the crop of the
    ``"Gold"`` box is used instead, as a stack of one.

    Args:
        img_dir: Directory that contains the image files.
        df: Mapping from image filename to a metadata dict with ``"Gold"``
            (one box) and ``"Boxes"`` (list of boxes); boxes are
            ``(x1, y1, x2, y2)``.
        labels: Iterable of ``"<filename>,<class id>"`` strings.
        to_blur: When true, the image is Gaussian-blurred before cropping.
        blur_kernel_size: Accepted for signature compatibility but ignored;
            the kernel is always derived from ``sigma`` as ``4 * sigma + 1``.
        sigma: Sigma handed to ``cv2.GaussianBlur``.
        p: Padding fraction handed to ``crop_image``.
        img_transforms: Callable applied to each crop. The stacking below
            relies on it returning equally shaped tensors.
    """
    def __init__(self, img_dir, df, labels, to_blur=True, blur_kernel_size=(65,65), sigma=16, p=0.15, img_transforms=None):
        self.img_dir = img_dir
        self.transforms = img_transforms
        self.to_blur = to_blur
        # Kernel size follows sigma; the blur_kernel_size argument is unused.
        self.blur_kernel_size = (4*sigma+1, 4*sigma+1)
        self.sigma = sigma
        self.p = p
        records = []
        for label in labels:
            key, cls = label.split(",")
            # Shallow-copy the entry so the caller's metadata is left untouched.
            record = dict(df[key])
            record["filename"] = key
            record["label"] = int(cls)
            records.append(record)
        # Despite the name this is a plain list of per-image dicts.
        self.df = records

    def __len__(self):
        """Return the number of labelled images."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load all crops of one image.

        Returns:
            ``(crops, labels, filename)``. ``crops`` stacks the transformed
            crops along a new first dimension and ``labels`` repeats the
            image label once per crop.

        Raises:
            FileNotFoundError: If OpenCV could not read the image file.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        record = self.df[idx]
        filename = record["filename"]
        img_name = os.path.join(self.img_dir, filename)
        image = cv2.imread(img_name)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError("Could not read image: %s" % img_name)
        if self.to_blur:
            image = cv2.GaussianBlur(image, self.blur_kernel_size, self.sigma)
        label = torch.as_tensor(record["label"], dtype=torch.int64)
        boxes = record["Boxes"]
        if len(boxes) == 0:
            # No boxes listed: fall back to the "Gold" crop as a stack of one.
            # It is only computed here because it is discarded otherwise.
            orig = crop_image(image, record["Gold"], self.p)
            if self.transforms:
                orig = self.transforms(orig)
            return orig.unsqueeze(0), label.unsqueeze(0), filename
        crps = []
        for bbs in boxes:
            crp_img = crop_image(image, bbs, self.p)
            if self.transforms:
                crp_img = self.transforms(crp_img)
            crps.append(crp_img)
        # Each box crop is treated as its own sample with the image's label.
        orig = torch.stack(crps, 0)
        label = torch.stack([label] * len(crps), 0)
        return orig, label, filename

In [3]:
def collate_fn(batch):
    """Regroup ``(data, label, filename)`` samples into three parallel lists.

    Nothing is stacked, so samples of different sizes can share a batch.
    """
    # Unpacking here keeps the requirement that every sample is a 3-tuple.
    samples = [(_data, _label, _filename) for _data, _label, _filename in batch]
    data_list = [sample[0] for sample in samples]
    label_list = [sample[1] for sample in samples]
    file_names = [sample[2] for sample in samples]
    return data_list, label_list, file_names

In [4]:
if __name__ == "__main__":
    # Smoke test: build a GbCropDataset from the training split and print the
    # labels, first sample size and filenames of the first batch.
    VAL_IMG_DIR = "C:/Users/Lakshmi vara prasad/Downloads/GBCU/imgs"
    VAL_JSON = "C:/Users/Lakshmi vara prasad/Downloads/GBCU/roi_pred.json"
    with open("C:/Users/Lakshmi vara prasad/Downloads/GBCU/train.txt", "r") as f:
        # One "<filename>,<class id>" entry per line.
        labels = [e.strip() for e in f.readlines()]
    with open(VAL_JSON, "r") as f:
        df = json.load(f)
    img_transforms = T.Compose([
        T.Resize((224,224)),
        T.ToTensor(),
    ])
    dataset = GbCropDataset(VAL_IMG_DIR, df, labels, img_transforms = img_transforms)
    loader = DataLoader(dataset, batch_size=32, collate_fn=utils.collate_fn)
    # Note: `labels` is rebound here from the split lines to the batch labels.
    images, labels, filename = next(iter(loader))
    print(labels)
    print(images[0].size())
    print(filename)


(tensor([0]), tensor([0]), tensor([1]), tensor([0, 0]), tensor([1]), tensor([1]), tensor([1]), tensor([1, 1]), tensor([0, 0]), tensor([1]), tensor([0]), tensor([0]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1, 1]), tensor([1, 1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]))
torch.Size([1, 3, 224, 224])
('im00049.jpg', 'im00748.jpg', 'im01229.jpg', 'im00730.jpg', 'im00203.jpg', 'im01031.jpg', 'im00868.jpg', 'im00300.jpg', 'im00422.jpg', 'im00497.jpg', 'im00134.jpg', 'im01150.jpg', 'im00662.jpg', 'im00061.jpg', 'im01237.jpg', 'im00984.jpg', 'im00095.jpg', 'im01155.jpg', 'im00303.jpg', 'im00719.jpg', 'im01097.jpg', 'im00741.jpg', 'im01254.jpg', 'im01193.jpg', 'im00824.jpg', 'im00772.jpg', 'im00707.jpg', 'im00059.jpg', 'im00088.jpg', 'im00075.jpg', 'im00918.jpg', 'im00437.jpg')


In [5]:
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install neptune-client




[notice] A new release of pip available: 22.2.2 -> 22.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [6]:
from __future__ import print_function, division
import argparse
import os
import json
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import transforms as T
import utils
from torch.optim.lr_scheduler import StepLR
from skimage import io, transform
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
from dataloader import GbDataset, GbRawDataset, GbCropDataset
from models import GbcNet 
# import neptune logger
import neptune.new as neptune
# Set plot style
import matplotlib.pyplot as plt
plt.style.use('ggplot')
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [7]:
def parse(argv=None):
    """Build the training configuration from command-line style arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the process arguments are parsed, as before. Passing a list lets
            notebook code override settings, e.g. ``parse(["--epochs", "3"])``.

    Returns:
        argparse.Namespace holding the parsed options. Unrecognised arguments
        (such as the ones a Jupyter kernel is started with) are ignored.
    """
    parser = argparse.ArgumentParser(description='Process arguments')
    parser.add_argument('--img_dir', dest="img_dir", default="C:/Users/Lakshmi vara prasad/Downloads/GBCU/imgs",
                        help="directory containing the images")
    parser.add_argument('--set_dir', dest="set_dir", default="C:/Users/Lakshmi vara prasad/Downloads/GBCU",
                        help="directory containing the train/test split files")
    parser.add_argument('--train_set_name', dest="train_set_name", default="train.txt",
                        help="file name of the training split inside set_dir")
    parser.add_argument('--test_set_name', dest="test_set_name", default="test.txt",
                        help="file name of the validation split inside set_dir")
    parser.add_argument('--meta_file', dest="meta_file", default="C:/Users/Lakshmi vara prasad/Downloads/GBCU/roi_pred.json",
                        help="JSON file with the per-image metadata")
    parser.add_argument('--epochs', dest="epochs", default=10, type=int,
                        help="number of training epochs")
    parser.add_argument('--lr', dest="lr", default=5e-3, type=float,
                        help="optimizer learning rate")
    parser.add_argument('--height', dest="height", default=224, type=int,
                        help="height passed to the resize transform")
    parser.add_argument('--width', dest="width", default=224, type=int,
                        help="width passed to the resize transform")
    parser.add_argument('--no_roi', action='store_true',
                        help="train and validate on whole images instead of crops")
    parser.add_argument('--pretrain', action='store_true',
                        help="forwarded to GbcNet as its pretrain flag")
    parser.add_argument('--load_model', action='store_true',
                        help="load model weights from load_path before training")
    parser.add_argument('--load_path', dest="load_path", default="C:/Users/Lakshmi vara prasad/Downloads/gbcnet.pth",
                        help="checkpoint used with --load_model")
    parser.add_argument('--save_dir', dest="save_dir", default="C:/Users/Lakshmi vara prasad/Download/outputs",
                        help="directory the per-epoch checkpoints are written to")
    parser.add_argument('--save_name', dest="save_name", default="gbcnet_1",
                        help="file name prefix of the saved checkpoints")
    parser.add_argument('--optimizer', dest="optimizer", default="sgd",
                        help='"sgd" selects SGD; any other value selects Adam')
    parser.add_argument('--batch_size', dest="batch_size", default=32, type=int,
                        help="training batch size")
    parser.add_argument('--att_mode', dest="att_mode", default="1",
                        help="forwarded to GbcNet as its att_mode")
    parser.add_argument('--va', action="store_true",
                        help="train with the epoch-dependent blur schedule")

    # parse_known_args tolerates foreign arguments instead of exiting.
    args, _ = parser.parse_known_args(argv)
    return args

In [8]:
def _read_label_lines(path):
    """Return the stripped, non-empty lines of a split file."""
    with open(path, "r") as f:
        return [line.strip() for line in f if line.strip()]


def _make_roi_train_loader(args, df, train_labels, img_transforms, epoch):
    """Build the crop-based training loader to use for ``epoch``."""
    # (first epoch after the stage, GbDataset settings, DataLoader settings)
    # NOTE(review): only the first stage disables shuffling -- confirm that
    # this is intended and not a copy/paste slip.
    blur_stages = [
        (10, dict(blur_kernel_size=(65, 65), sigma=16), dict(shuffle=False, num_workers=5)),
        (15, dict(blur_kernel_size=(33, 33), sigma=8), dict(shuffle=True, num_workers=1)),
        (20, dict(blur_kernel_size=(17, 17), sigma=4), dict(shuffle=True, num_workers=1)),
        (25, dict(blur_kernel_size=(9, 9), sigma=2), dict(shuffle=True, num_workers=1)),
        (30, dict(blur_kernel_size=(5, 5), sigma=1), dict(shuffle=True, num_workers=1)),
    ]
    # Without --va, or once the schedule is exhausted, train on unblurred crops.
    dataset_kwargs = dict(to_blur=False)
    loader_kwargs = dict(shuffle=True, num_workers=1)
    if args.va:
        for stage_end, stage_dataset_kwargs, stage_loader_kwargs in blur_stages:
            if epoch < stage_end:
                dataset_kwargs = stage_dataset_kwargs
                loader_kwargs = stage_loader_kwargs
                break
    train_dataset = GbDataset(args.img_dir, df, train_labels,
                              img_transforms=img_transforms, **dataset_kwargs)
    return DataLoader(train_dataset, batch_size=args.batch_size, **loader_kwargs)


def _evaluate(net, val_loader, use_roi):
    """Run ``net`` over ``val_loader`` and return ``(y_true, y_pred)`` lists.

    The loader must yield one image per batch. With ``use_roi`` the batch
    holds all crops of that image and the image-level prediction is the
    highest class index predicted for any crop.
    """
    y_true, y_pred = [], []
    net.eval()
    with torch.no_grad():
        for images, targets, _ in val_loader:
            images = images.float()
            if use_roi:
                # Drop the batch dimension so the crops become the batch.
                images = images.squeeze(0)
                outputs = net(images)
                _, pred = torch.max(outputs, dim=1)
                y_true.append(targets.tolist()[0][0])
                y_pred.append(torch.max(pred).item())
            else:
                outputs = net(images)
                _, pred = torch.max(outputs, dim=1)
                y_true.append(targets.tolist()[0])
                y_pred.append(pred.item())
    return y_true, y_pred


def main(args):
    """Train GbcNet and validate and checkpoint it after every epoch.

    Args:
        args: Namespace as produced by ``parse()``.

    Per epoch, one line with the mean training loss, validation accuracy,
    specificity and sensitivity is printed and the model weights are written
    to ``<save_dir>/<save_name>_epoch_<epoch>.pth``.
    """
    img_transforms = T.Compose([
        T.Resize((args.width, args.height)),
        #T.RandomHorizontalFlip(0.25),
        T.ToTensor(),
    ])
    val_transforms = T.Compose([T.Resize((args.width, args.height)),
                                T.ToTensor()])

    with open(args.meta_file, "r") as f:
        df = json.load(f)

    train_labels = _read_label_lines(os.path.join(args.set_dir, args.train_set_name))
    val_labels = _read_label_lines(os.path.join(args.set_dir, args.test_set_name))

    train_loader = None
    if args.no_roi:
        train_dataset = GbRawDataset(args.img_dir, df, train_labels, img_transforms=img_transforms)
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=5)
        val_dataset = GbRawDataset(args.img_dir, df, val_labels, img_transforms=val_transforms)
    else:
        val_dataset = GbCropDataset(args.img_dir, df, val_labels, to_blur=False, img_transforms=val_transforms)
    # Validation handles exactly one image per step (squeeze(0) / .item()),
    # and images can have different numbers of crops, so the batch size has
    # to be 1. Order does not influence the metrics, hence no shuffling.
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=5)

    net = GbcNet(num_cls=3, pretrain=args.pretrain, att_mode=args.att_mode)

    if args.load_model:
        # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
        net.load_state_dict(torch.load(args.load_path, map_location="cpu"))
    net.net = net.net.float()

    params = [p for p in net.parameters() if p.requires_grad]

    total_params = sum(p.numel() for p in net.parameters())
    print("Total Param: ", total_params)

    criterion = nn.CrossEntropyLoss()

    if args.optimizer == "sgd":
        optimizer = optim.SGD(params, lr=args.lr, momentum=0.9, weight_decay=0.0005)
    else:
        optimizer = optim.Adam(params, lr=args.lr)
    # NOTE(review): the scheduler is created but its step() call below is
    # commented out, so the learning rate stays constant.
    lr_sched = StepLR(optimizer, step_size=5, gamma=0.8)

    os.makedirs(args.save_dir, exist_ok=True)

    train_loss = []

    for epoch in range(args.epochs):
        if not args.no_roi:
            train_loader = _make_roi_train_loader(args, df, train_labels, img_transforms, epoch)

        running_loss = 0.0
        total_step = len(train_loader)
        for images, targets, _ in train_loader:
            images = images.float()
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs.cpu(), targets.cpu())
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # max() avoids a ZeroDivisionError on an empty training split.
        train_loss.append(running_loss / max(total_step, 1))

        y_true, y_pred = _evaluate(net, val_loader, use_roi=not args.no_roi)
        acc = accuracy_score(y_true, y_pred)
        # Fixing the labels keeps the matrix 3x3 even when a class is missing
        # from this epoch's targets and predictions.
        cfm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])
        # Specificity: class 0/1 samples that were not predicted as class 2.
        negatives = np.sum(cfm[0]) + np.sum(cfm[1])
        spec = np.sum(cfm[:2, :2]) / negatives if negatives else float("nan")
        # Sensitivity: class 2 samples predicted as class 2.
        positives = np.sum(cfm[2])
        sens = cfm[2][2] / positives if positives else float("nan")
        print('Epoch: [{}/{}] Train-Loss: {:.4f} Val-Acc: {:.4f} Val-Spec: {:.4f} Val-Sens: {:.4f}'\
                .format(epoch+1, args.epochs, train_loss[-1], acc, spec, sens))

        _name = "%s_epoch_%s.pth"%(args.save_name, epoch)
        save_path = os.path.join(args.save_dir, _name)
        torch.save(net.state_dict(), save_path)

        net.train()
        #lr_sched.step()

In [None]:
if __name__ == "__main__":
    # Parse the configuration (defaults apply when no matching arguments are
    # given) and start training.
    args = parse()
    main(args)

Total Param:  26903627
