# *VinBigData 2-class classifier*
![](https://images.medicinenet.com/images/article/main_image/chest-x-ray.jpg)

* This notebook is the next step after this kernel:
[https://www.kaggle.com/khaledmgamal/detectron2-vinbigdata](https://www.kaggle.com/khaledmgamal/detectron2-vinbigdata)

* It takes the submission file from the above kernel and applies a 2-class classifier that labels each image as either normal or having thoracic abnormalities. This is a post-processing step intended to reduce false positives.

In [None]:
import gc
import os
from pathlib import Path
import random
import sys

from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import scipy as sp


import matplotlib.pyplot as plt
import seaborn as sns

# --- models ---
from sklearn import preprocessing
from sklearn.model_selection import KFold

import torch

In [None]:
from dataclasses import dataclass, field
from typing import Dict, Any, Tuple, Union, List


@dataclass
class Flags:
    """Run configuration for the two-class classifier.

    Every field has a default; `update` overlays values from a plain dict and
    rejects keys that are not declared here.
    """

    # --- General ---
    debug: bool = True
    outdir: str = "results/det"
    device: str = "cuda:0"

    # --- Data ---
    imgdir_name: str = "vinbigdata-chest-xray-resized-png-256x256"
    # split_mode: str = "all_train"  # all_train or valid20
    seed: int = 111
    target_fold: int = 0  # 0~4
    label_smoothing: float = 0.0

    # --- Model ---
    model_name: str = "resnet18"
    model_mode: str = "normal"  # normal, cnn_fixed supported

    # --- Training ---
    epoch: int = 20
    batchsize: int = 8
    valid_batchsize: int = 16
    num_workers: int = 4
    snapshot_freq: int = 5
    scheduler_type: str = ""
    scheduler_kwargs: Dict[str, Any] = field(default_factory=dict)
    scheduler_trigger: List[Union[int, str]] = field(default_factory=lambda: [1, "iteration"])
    aug_kwargs: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    mixup_prob: float = -1.0  # Apply mixup augmentation when positive value is set.

    def update(self, param_dict: Dict) -> "Flags":
        """Overwrite fields in place from `param_dict` and return `self`.

        Keys are applied in iteration order; the first key that is not an
        existing attribute raises `ValueError` (keys before it stay applied).
        """
        for name in param_dict:
            if hasattr(self, name):
                setattr(self, name, param_dict[name])
                continue
            raise ValueError(f"[ERROR] Unexpected key for flag = {name}")
        return self

In [None]:
# Overrides applied on top of the `Flags` defaults (see `Flags().update(flags_dict)`).
# Every key here must be a field declared on `Flags`, otherwise `update` raises.
flags_dict = {
    "debug": False,  # Change to True for fast debug run!
    "outdir": "results/tmp_debug",
    # Data
    "imgdir_name": "vinbigdata-chest-xray-resized-png-256x256",
    # Model
    "model_name": "resnet18",
    # Training
    "num_workers": 4,
    "epoch": 15,
    "batchsize": 16,
    "scheduler_type": "CosineAnnealingWarmRestarts",
    # T_0 was derived as 15000 * 15 epoch // (batchsize=8).
    # NOTE(review): "batchsize" above is 16, not 8 - confirm T_0 is still the intended value.
    "scheduler_kwargs": {"T_0": 28125},
    "scheduler_trigger": [1, "iteration"],
    # Each key is an albumentations class name and each value its keyword
    # arguments; the dict-driven `Transform` resolves them with `getattr(A, name)`.
    "aug_kwargs": {
        "HorizontalFlip": {"p": 0.5},
        "ShiftScaleRotate": {"scale_limit": 0.15, "rotate_limit": 10, "p": 0.5},
        "RandomBrightnessContrast": {"p": 0.5},
        "CoarseDropout": {"max_holes": 8, "max_height": 25, "max_width": 25, "p": 0.5},
        "Blur": {"blur_limit": [3, 7], "p": 0.5},
        "Downscale": {"scale_min": 0.25, "scale_max": 0.9, "p": 0.3},
        "RandomGamma": {"gamma_limit": [80, 120], "p": 0.6},
    }
}

In [None]:
import dataclasses

# args = parse()
print("torch", torch.__version__)
# Start from the dataclass defaults and overlay the notebook-level overrides.
flags = Flags().update(flags_dict)
print("flags", flags)


# --- Read data ---
# Kaggle mounts every attached dataset under /kaggle/input.
inputdir = Path("/kaggle/input")
# Competition data (train.csv with one row per annotation).
datadir = inputdir / "vinbigdata-chest-xray-abnormalities-detection"
# Directory of the resized PNG images selected by `flags.imgdir_name`.
imgdir = inputdir / flags.imgdir_name

# Read in the data CSV files
train = pd.read_csv(datadir / "train.csv")

In [None]:
import pickle
from pathlib import Path
from typing import Optional

import cv2
import numpy as np
import pandas as pd
#from detectron2.structures import BoxMode
from tqdm import tqdm


def get_vinbigdata_dicts(
    imgdir: Path,
    train_df: pd.DataFrame,
    train_data_type: str = "original",
    use_cache: bool = True,
    debug: bool = True,
    target_indices: Optional[np.ndarray] = None,
):
    """Build (or load from a pickle cache) one record per training image.

    Args:
        imgdir: Directory holding `train_meta.csv` and the `train/` PNG folder.
        train_df: Annotation table; the columns read here are `image_id`,
            `class_id`, `x_min`, `y_min`, `x_max` and `y_max`.
        train_data_type: Only used as part of the cache file name.
        use_cache: When False the records are rebuilt even if a cache exists.
        debug: When True only the first 500 rows of `train_meta.csv` are used.
        target_indices: Optional positions used to select a subset of records.

    Returns:
        A list of dicts, each with:

        - `file_name`: path of the resized PNG image.
        - `image_id`: the image id string from `train_meta.csv`.
        - `height` / `width`: size of the resized image (probed from the
          first image and reused for all records).
        - `annotations`: list of `{"bbox": [x_min, y_min, x_max, y_max],
          "category_id": class_id}` with the box rescaled to the resized image.
          Rows with `class_id == 14` ("No finding") are skipped, so the list is
          empty for images without findings.

    Raises:
        FileNotFoundError: If the probe image cannot be read.
    """
    debug_str = f"_debug{int(debug)}"
    train_data_type_str = f"_{train_data_type}"
    cache_path = Path(".") / f"dataset_dicts_cache{train_data_type_str}{debug_str}.pkl"
    if not use_cache or not cache_path.exists():
        print("Creating data...")
        train_meta = pd.read_csv(imgdir / "train_meta.csv")
        if debug:
            train_meta = train_meta.iloc[:500]  # For debug....

        # Load 1 image to get image size.
        image_id = train_meta.loc[0, "image_id"]
        image_path = str(imgdir / "train" / f"{image_id}.png")
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        resized_height, resized_width = image.shape[:2]
        print(f"image shape: {image.shape}")

        # Group the annotation rows once instead of scanning the whole
        # table for every image; sort=False keeps the original row order.
        rows_by_image = {
            key: rows for key, rows in train_df.groupby("image_id", sort=False)
        }

        dataset_dicts = []
        for _, train_meta_row in tqdm(train_meta.iterrows(), total=len(train_meta)):
            # train_meta.csv rows are (image_id, original height, original width).
            image_id, height, width = train_meta_row.values
            record = {
                "file_name": str(imgdir / "train" / f"{image_id}.png"),
                "image_id": image_id,
                "height": resized_height,
                "width": resized_width,
            }
            objs = []
            image_rows = rows_by_image.get(image_id)
            if image_rows is not None:
                for _, row in image_rows.iterrows():
                    class_id = row["class_id"]
                    if class_id == 14:
                        # It is "No finding"
                        # This annotator does not find anything, skip.
                        continue
                    # Rescale the box from original to resized pixel coordinates.
                    h_ratio = resized_height / height
                    w_ratio = resized_width / width
                    bbox_resized = [
                        int(row["x_min"]) * w_ratio,
                        int(row["y_min"]) * h_ratio,
                        int(row["x_max"]) * w_ratio,
                        int(row["y_max"]) * h_ratio,
                    ]
                    objs.append({"bbox": bbox_resized, "category_id": class_id})
            record["annotations"] = objs
            dataset_dicts.append(record)
        with open(cache_path, mode="wb") as f:
            pickle.dump(dataset_dicts, f)

    # NOTE: the cache is a local pickle written by this function; never point
    # this at an untrusted file, since unpickling can execute arbitrary code.
    print(f"Load from cache {cache_path}")
    with open(cache_path, mode="rb") as f:
        dataset_dicts = pickle.load(f)
    if target_indices is not None:
        dataset_dicts = [dataset_dicts[i] for i in target_indices]
    return dataset_dicts

############################################################################################################

def get_vinbigdata_dicts_test(
    imgdir: Path, test_meta: pd.DataFrame, use_cache: bool = True, debug: bool = True,
):
    """Build (or load from a pickle cache) one record per test image.

    Args:
        imgdir: Directory holding the `test/` PNG folder.
        test_meta: Table whose rows are (image_id, original height, original
            width); it must have an `image_id` column and an index label 0.
        use_cache: When False the records are rebuilt even if a cache exists.
        debug: When True only the first 500 rows of `test_meta` are used.

    Returns:
        A list of dicts, each with:

        - `file_name`: path of the resized PNG image.
        - `image_id`: the image id string from `test_meta`.
        - `height` / `width`: size of the resized image (probed from the
          first image and reused for all records).

    Raises:
        FileNotFoundError: If the probe image cannot be read.
    """
    debug_str = f"_debug{int(debug)}"
    cache_path = Path(".") / f"dataset_dicts_cache_test{debug_str}.pkl"
    if not use_cache or not cache_path.exists():
        print("Creating data...")
        # test_meta = pd.read_csv(imgdir / "test_meta.csv")
        if debug:
            test_meta = test_meta.iloc[:500]  # For debug....

        # Load 1 image to get image size.
        image_id = test_meta.loc[0, "image_id"]
        image_path = str(imgdir / "test" / f"{image_id}.png")
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        resized_height, resized_width = image.shape[:2]
        print(f"image shape: {image.shape}")

        dataset_dicts = []
        for _, test_meta_row in tqdm(test_meta.iterrows(), total=len(test_meta)):
            # The original size columns are unpacked only to validate the row
            # layout; the records store the resized size.
            image_id, _orig_height, _orig_width = test_meta_row.values
            dataset_dicts.append(
                {
                    "file_name": str(imgdir / "test" / f"{image_id}.png"),
                    "image_id": image_id,
                    "height": resized_height,
                    "width": resized_width,
                }
            )
        with open(cache_path, mode="wb") as f:
            pickle.dump(dataset_dicts, f)

    # NOTE: the cache is a local pickle written by this function; never point
    # this at an untrusted file, since unpickling can execute arbitrary code.
    print(f"Load from cache {cache_path}")
    with open(cache_path, mode="rb") as f:
        dataset_dicts = pickle.load(f)
    return dataset_dicts

# Mixup data augmentation
What is mixup?
As the name kind of suggests, the authors of the mixup article propose to train the model on a mix of the pictures of the training set. Let’s say we’re on CIFAR10 for instance, then instead of feeding the model the raw images, we take two (which could be in the same class or not) and do a linear combination of them: in terms of tensor it’s

new_image = t * image1 + (1-t) * image2
where t is a float between 0 and 1. Then the target we assign to that image is the same combination of the original targets:

new_target = t * target1 + (1-t) * target2
assuming your targets are one-hot encoded (which isn’t the case in pytorch usually). And that’s as simple as this.

Refer to this link: [https://forums.fast.ai/t/mixup-data-augmentation/22764](https://forums.fast.ai/t/mixup-data-augmentation/22764)

In [None]:
import numpy
import six
import torch
from torch.utils.data.dataset import Dataset


class DatasetMixin(Dataset):
    """Base dataset adding slice/list/array indexing and an optional transform.

    Subclasses implement `__len__` and `get_example`; `__getitem__` dispatches
    on the index type and routes every element through `get_example_wrapper`.
    """

    def __init__(self, transform=None):
        # Optional callable applied to every example returned by `get_example`.
        self.transform = transform

    def __getitem__(self, index):
        """Returns an example or a sequence of examples.

        A tensor index is first converted with `tolist()`. A slice, list or
        numpy array returns a list of examples; anything else is treated as a
        single index and returns one example.
        """
        if torch.is_tensor(index):
            index = index.tolist()
        if isinstance(index, slice):
            start, stop, step = index.indices(len(self))
            return [self.get_example_wrapper(i) for i in range(start, stop, step)]
        if isinstance(index, (list, numpy.ndarray)):
            return [self.get_example_wrapper(i) for i in index]
        return self.get_example_wrapper(index)

    def __len__(self):
        """Returns the number of data points."""
        raise NotImplementedError

    def get_example_wrapper(self, i):
        """Wrapper of `get_example`, to apply `transform` if necessary"""
        example = self.get_example(i)
        if self.transform:
            example = self.transform(example)
        return example

    def get_example(self, i):
        """Returns the i-th example.

        Implementations should override it. It should raise :class:`IndexError`
        if the index is invalid.

        Args:
            i (int): The index of the example.

        Returns:
            The i-th example.

        """
        raise NotImplementedError
        
import cv2
import numpy as np


class VinbigdataTwoClassDataset(DatasetMixin):
    """Image dataset labelling each X-ray as normal (0) or abnormal (1).

    Args:
        dataset_dicts: Records with at least `file_name`; in train mode each
            record also needs `annotations` (non-empty means abnormal).
        image_transform: Optional callable applied to the HWC image array
            right after it is read.
        transform: Optional callable applied to the final example
            (handled by `DatasetMixin`).
        train: When True examples are `(image, target)`; otherwise only the
            image is returned.
        mixup_prob: Probability of blending an example with a random second
            one; mixup is disabled unless this is positive. It is only applied
            in train mode.
        label_smoothing: When positive, label 0 becomes `label_smoothing` and
            label 1 becomes `1 - label_smoothing`.
    """

    def __init__(self, dataset_dicts, image_transform=None, transform=None, train: bool = True,
                 mixup_prob: float = -1.0, label_smoothing: float = 0.0):
        super().__init__(transform=transform)
        self.dataset_dicts = dataset_dicts
        self.image_transform = image_transform
        self.train = train
        self.mixup_prob = mixup_prob
        self.label_smoothing = label_smoothing

    def _get_single_example(self, i):
        """Read record `i` and return `(image, label)`; label is None when not training."""
        d = self.dataset_dicts[i]
        filename = d["file_name"]

        img = cv2.imread(filename)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {filename}")
        if self.image_transform:
            img = self.image_transform(img)
        # HWC array -> CHW float32 tensor.
        img = torch.tensor(np.transpose(img, (2, 0, 1)).astype(np.float32))

        if not self.train:
            # Only return img
            return img, None

        label = float(len(d["annotations"]) > 0)  # 0 normal, 1 abnormal
        if self.label_smoothing > 0:
            # Pull the hard label towards the other class.
            if label == 0:
                label = label + self.label_smoothing
            else:
                label = label - self.label_smoothing
        return img, label

    def get_example(self, i):
        """Return `(image, [1 - label, label])` in train mode, else just the image."""
        img, label = self._get_single_example(i)

        if not self.train:
            # Only return img. Mixup is deliberately skipped here: blending two
            # unrelated images at inference time would corrupt the prediction.
            return img

        if self.mixup_prob > 0. and np.random.uniform() < self.mixup_prob:
            j = np.random.randint(0, len(self.dataset_dicts))
            p = np.random.uniform()
            img2, label2 = self._get_single_example(j)
            img = img * p + img2 * (1 - p)
            label = label * p + label2 * (1 - p)

        # Two-element soft target: [P(normal), P(abnormal)].
        label_logit = torch.tensor([1 - label, label], dtype=torch.float32)
        return img, label_logit

    def __len__(self):
        return len(self.dataset_dicts)

In [None]:
# Build (or load from the pickle cache) the per-image records for the whole
# training set; debug=False disables the 500-image truncation.
dataset_dicts = get_vinbigdata_dicts(imgdir, train, debug=False)


In [None]:
# Un-augmented dataset; train defaults to True, so each item is (image, target).
dataset = VinbigdataTwoClassDataset(dataset_dicts)


In [None]:
import albumentations as A


class Transform:
    """Fixed albumentations pipeline: flip, shift/scale/rotate, brightness/contrast.

    Each argument is the probability of applying the corresponding step.
    """

    def __init__(
        self, hflip_prob: float = 0.5, ssr_prob: float = 0.5, random_bc_prob: float = 0.5
    ):
        flip = A.HorizontalFlip(p=hflip_prob)
        shift_scale_rotate = A.ShiftScaleRotate(
            shift_limit=0.0625, scale_limit=0.1, rotate_limit=10, p=ssr_prob
        )
        brightness_contrast = A.RandomBrightnessContrast(p=random_bc_prob)
        self.transform = A.Compose([flip, shift_scale_rotate, brightness_contrast])

    def __call__(self, image):
        """Run the pipeline on `image` and return the augmented image."""
        augmented = self.transform(image=image)
        return augmented["image"]

In [None]:
# Same records as `dataset`, but each image goes through the default
# flip / shift-scale-rotate / brightness-contrast augmentation when read.
aug_dataset = VinbigdataTwoClassDataset(dataset_dicts, image_transform=Transform())

In [None]:
# Show the first three examples with their targets.
# NOTE(review): this reads from the un-augmented `dataset`, not `aug_dataset` - confirm intended.
fig, axes = plt.subplots(3, 1, figsize=(20,20))
for index in range(3):
    # `label` is the two-element target tensor [1 - label, label].
    img, label = dataset[index]
    # The dataset yields CHW float tensors; convert back to HWC and scale
    # by 255 so imshow receives values in [0, 1].
    axes[index].imshow(img.cpu().numpy().transpose((1, 2, 0)) / 255.)
    axes[index].set_title(f"{index}-th image: label {label}")
    

# Install timm library 
Refer to this link for more info: [https://rwightman.github.io/pytorch-image-models/](https://rwightman.github.io/pytorch-image-models/)


In [None]:
!pip install timm

In [None]:
from torch import nn
from torch.nn import Linear


class CNNFixedPredictor(nn.Module):
    """CNN feature extractor followed by a linear classification head.

    Despite the class name, the backbone is trainable by default (this matches
    the original behaviour). Pass `freeze_cnn=True` to keep the backbone
    weights fixed and train only the linear head.

    Args:
        cnn: Backbone module exposing a `num_features` attribute and returning
            a feature tensor whose last dimension is `num_features`.
        num_classes: Number of outputs of the linear head.
        freeze_cnn: When True, `requires_grad` is turned off for every
            backbone parameter; when False (default) it is turned on.
    """

    def __init__(self, cnn: nn.Module, num_classes: int = 2, freeze_cnn: bool = False):
        super().__init__()
        self.cnn = cnn
        self.lin = Linear(cnn.num_features, num_classes)

        print("cnn.num_features", cnn.num_features)
        # Freezing follows the torchvision fine-tuning recipe:
        # https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
        for param in self.cnn.parameters():
            param.requires_grad = not freeze_cnn

    def forward(self, x):
        """Return the class logits for input batch `x`."""
        feat = self.cnn(x)
        return self.lin(feat)


# Creating the resnet18 model to be trained as a classifier:
Refer to this link for more info: [https://rwightman.github.io/pytorch-image-models/feature_extraction/](https://rwightman.github.io/pytorch-image-models/feature_extraction/)


In [None]:
import timm
# Pretrained backbone selected by `flags.model_name`.
# NOTE(review): num_classes=0 presumably makes timm drop its classifier and
# return pooled features (see the timm feature-extraction docs) - confirm.
timm_model = timm.create_model(flags.model_name, pretrained=True, num_classes=0, in_chans=3)
# Wrap the backbone with a 2-output linear head (normal vs. abnormal).
model=CNNFixedPredictor(timm_model, num_classes=2)

In [None]:
def model_paramters_report(model):
    """Print the total and the trainable parameter counts of `model`."""
    total_count = 0
    trainable_count = 0
    for param in model.parameters():
        size = param.numel()
        total_count += size
        if param.requires_grad:
            trainable_count += size
    print(f"number of parameters = {total_count} number of trainable parameters = {trainable_count} ")
  

In [None]:
model_paramters_report(model)

In [None]:
model

In [None]:
#for name, param in model.named_parameters():
    #if 'cnn.layer4' in str(name) or 'cnn.layer3.1' in str(name) :
    #print(name)
    #param.requires_grad=True 
     

In [None]:
from typing import Dict

import albumentations as A


class Transform:
    """Albumentations pipeline built from a `{class name: kwargs}` mapping.

    Each key of `aug_kwargs` is looked up as an attribute of the albumentations
    module and instantiated with its kwargs, in the mapping's iteration order.
    """

    def __init__(self, aug_kwargs: Dict):
        steps = []
        for aug_name, aug_params in aug_kwargs.items():
            aug_class = getattr(A, aug_name)
            steps.append(aug_class(**aug_params))
        self.transform = A.Compose(steps)

    def __call__(self, image):
        """Run the pipeline on `image` and return the augmented image."""
        augmented = self.transform(image=image)
        return augmented["image"]

In [None]:
model

In [None]:
import torch
import torch.nn.functional as F


def cross_entropy_with_logits(input, target, dim=-1):
    """Cross entropy between logits `input` and a (possibly soft) `target` distribution.

    The log-softmax of `input` is taken over `dim`, weighted by `target`,
    summed over `dim`, and the result is averaged over the remaining elements.
    """
    log_probs = F.log_softmax(input, dim)
    per_sample = torch.sum(-target * log_probs, dim)
    return per_sample.mean()

def accuracy(y: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Fraction of entries whose arg-max over the last axis of `y` equals `t`.

    `y` holds per-class scores on its last axis and `t` holds integer class
    labels with the shape of `y` minus that axis.
    """
    assert y.shape[:-1] == t.shape, f"y {y.shape}, t {t.shape} is inconsistent."
    predicted = torch.max(y.detach(), dim=-1).indices
    n_correct = (predicted == t).sum().float()
    return n_correct / t.nelement()


def accuracy_with_logits(y: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Accuracy when the target `t` is given as per-class scores like `y`.

    The target is reduced to hard labels by arg-max over the last axis and
    then compared against the predictions with `accuracy`.
    """
    assert y.shape == t.shape
    hard_labels = torch.max(t.detach(), dim=-1).indices
    return accuracy(y, hard_labels)




def predict_proba(model,data_loader):
    """Run `model` over `data_loader` and return the softmax probabilities.

    The model is moved to the module-level `device` and put in eval mode.
    Batches may be tensors or tuples/lists whose first element is the image
    tensor. Each batch of probabilities is moved to the CPU before being
    collected, so the result can be converted with `.numpy()` regardless of
    the device used for inference and does not hold on to GPU memory.

    Returns:
        A CPU tensor of shape (number of examples, number of classes).
    """
    #device: torch.device = next(self.parameters()).device
    y_list = []
    model.to(device).eval()
    with torch.no_grad():
        for batch in tqdm(data_loader):
            if isinstance(batch, (tuple, list)):
                # Assumes first argument is "image"
                batch = batch[0]
            logits = model(batch.to(device))
            y_list.append(torch.softmax(logits, dim=-1).cpu())
    pred = torch.cat(y_list)
    return pred



In [None]:
# Use the GPU when available; the training / inference helpers read this global.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Optimise only the parameters that are marked trainable.
optimizer = torch.optim.Adam([param for param in model.parameters() if param.requires_grad ], lr=0.001)
# NOTE(review): the visible training / evaluation calls pass
# `cross_entropy_with_logits` as the criterion, so this loss looks unused - confirm.
loss_function=nn.BCEWithLogitsLoss()
# T_0 repeats the value from flags_dict["scheduler_kwargs"].
# NOTE(review): `epoch_training` has its `lr_sched.step()` call commented out,
# so this schedule is never advanced - confirm that is intended.
lr_sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=28125)

In [None]:
from tqdm import tqdm_notebook

def epoch_training(model,train_loader,optimizer,scheduler,criterion):
    """Train `model` for one pass over `train_loader` and return the mean batch loss.

    Args:
        model: Module to train; it is moved to the module-level `device`.
        train_loader: Iterable of `(input, target)` batches with a length.
        optimizer: Optimizer stepped once per batch.
        scheduler: Accepted for interface compatibility; it is NOT stepped
            here (the per-iteration step was disabled in the original code).
        criterion: Callable `(output, target) -> scalar loss tensor`.

    Returns:
        The sum of the per-batch losses divided by the number of batches.
    """
    # `Module.to` moves the parameters in place, so an optimizer created
    # earlier from these parameters remains valid.
    model.to(device)
    model.train()
    running_loss = 0.0
    tqdm_train = tqdm(train_loader, total=int(len(train_loader)), position=0, leave=True)
    for input_, target in tqdm_train:
        # `Tensor.to` returns a new tensor; the result must be re-assigned,
        # otherwise the batch silently stays on its original device.
        input_ = input_.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        output = model(input_)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        #lr_sched.step()
        batch_loss = loss.item()
        tqdm_train.set_postfix(loss=batch_loss)
        running_loss += batch_loss
    epoch_loss = running_loss / len(train_loader)

    return epoch_loss

def validation_epoch(model,valid_loader,criterion,accuracy):
    """Evaluate `model` on `valid_loader` without computing gradients.

    Args:
        model: Module to evaluate; it is moved to the module-level `device`.
        valid_loader: Iterable of `(input, target)` batches with a length;
            targets carry per-class scores on their last axis.
        criterion: Callable `(output, target) -> scalar loss tensor`.
        accuracy: Callable `(output, target) -> scalar accuracy`.

    Returns:
        `(epoch_loss, epoch_acc, model_outputs, labels)` where the first two
        are Python floats averaged over batches, and the last two are flat
        lists of predicted / true class indices (arg-max over the last axis).
    """
    model.to(device)
    model.eval()
    running_acc = 0.0
    running_loss = 0.0
    labels = []
    model_outputs = []
    tqdm_val = tqdm(valid_loader, total=int(len(valid_loader)), position=0, leave=True)
    with torch.no_grad():
        for input_, target in tqdm_val:
            # `Tensor.to` is not in-place; keep the returned tensors.
            input_ = input_.to(device)
            target = target.to(device)
            output = model(input_)
            # Move indices to the CPU before converting, so this also works
            # when the model runs on a GPU.
            model_outputs += torch.max(output, dim=-1)[1].cpu().tolist()
            labels += torch.max(target, dim=-1)[1].cpu().tolist()
            loss = criterion(output, target).item()
            # Convert to a plain float so the running sum and the progress
            # bar do not hold on to device tensors.
            acc = float(accuracy(output, target))
            tqdm_val.set_postfix(loss=loss, acc=acc)
            running_acc += acc
            running_loss += loss
    epoch_loss = running_loss / len(valid_loader)
    epoch_acc = running_acc / len(valid_loader)

    return epoch_loss,epoch_acc,model_outputs,labels
from sklearn import metrics

def model_training(model,train_loader,valid_loader,optimizer,scheduler,criterion,accuracy,epochs):
    """Train for `epochs` epochs, checkpointing whenever validation loss improves.

    After every training epoch the model is evaluated on `valid_loader`; the
    confusion matrix and classification report are printed, and when the
    validation loss is the lowest seen so far a checkpoint is written to
    'vinbigdata_checkpoint.pt'. When `valid_loader` is None no evaluation is
    done and no checkpoint is written.

    Args:
        model: Module to train.
        train_loader: Loader of `(input, target)` training batches.
        valid_loader: Loader of validation batches, or None to skip validation.
        optimizer: Optimizer used for training and stored in the checkpoint.
        scheduler: Forwarded to `epoch_training`.
        criterion: Loss callable `(output, target) -> scalar tensor`.
        accuracy: Accuracy callable forwarded to `validation_epoch`.
        epochs: Number of epochs to run.
    """
    best_val = np.inf
    for epoch in range(epochs):
        epoch_loss = epoch_training(model, train_loader, optimizer, scheduler, criterion)
        print(f"Epoch={epoch} , Loss={epoch_loss}")

        if valid_loader is None:
            continue

        val_epoch_loss, val_epoch_acc, model_outputs, labels = validation_epoch(
            model, valid_loader, criterion, accuracy
        )
        print(f"Epoch={epoch} , val Loss={val_epoch_loss} ,Val Accuracy={val_epoch_acc}")
        print(metrics.confusion_matrix(labels, model_outputs).T)
        print(metrics.classification_report(labels, model_outputs, digits=3))
        # Same report again as a dict, to pull out the summary metrics.
        metrics_ = metrics.classification_report(labels, model_outputs, digits=3, output_dict=True)
        acc = metrics_['accuracy']
        f1_macro_avg = metrics_['macro avg']['f1-score']
        f1_weighted_avg = metrics_['weighted avg']['f1-score']

        if val_epoch_loss < best_val:
            print(f"---------saving model---------with val loss = {val_epoch_loss}")
            best_val = val_epoch_loss
            checkpoint = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'Val Loss': val_epoch_loss,
                'Val Accuracy': acc,
                'f1 macro avg': f1_macro_avg,
                'f1 weighted avg': f1_weighted_avg,
            }
            torch.save(checkpoint, 'vinbigdata_checkpoint.pt')

# Stratified K Fold  
Stratified K-fold cross-validation can be used to check model performance, because the dataset is relatively small. (The cell below is currently disabled: its code is wrapped in a string literal.)

In [None]:
'''
from sklearn.model_selection import StratifiedKFold
from torch import nn, optim
from torch.utils.data.dataloader import DataLoader

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=flags.seed)
# skf.get_n_splits(None, None)
y = np.array([int(len(d["annotations"]) > 0) for d in dataset_dicts])
split_inds = list(skf.split(dataset_dicts, y))

for fold in range(len(split_inds)):
    train_inds, valid_inds = split_inds[fold]  # 0th fold
    train_dataset = VinbigdataTwoClassDataset(
        [dataset_dicts[i] for i in train_inds],
        image_transform=Transform(flags.aug_kwargs),
        mixup_prob=flags.mixup_prob,
        label_smoothing=flags.label_smoothing,
    )
    valid_dataset = VinbigdataTwoClassDataset([dataset_dicts[i] for i in valid_inds])
    train_loader = DataLoader(
        train_dataset,
        batch_size=flags.batchsize,
        num_workers=flags.num_workers,
        shuffle=True,
        pin_memory=True,
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=flags.valid_batchsize,
        num_workers=flags.num_workers,
        shuffle=False,
        pin_memory=True,
    )
    timm_model = timm.create_model(flags.model_name, pretrained=True, num_classes=0, in_chans=3)
    model=CNNFixedPredictor(timm_model, num_classes=2)
    
    model_training(model,train_loader,valid_loader,optimizer,lr_sched,
                   cross_entropy_with_logits,
                   accuracy_with_logits,
                   1
                   )
'''                   

In [None]:
from sklearn.model_selection import train_test_split
from torch.utils.data.dataloader import DataLoader

# Hold out 20% of the records for validation, with a fixed seed.
# NOTE(review): the split is not stratified by label - confirm that is acceptable.
train_dicts,val_dicts = train_test_split(dataset_dicts, test_size=0.20, random_state=42)


# Training dataset. Augmentation is currently disabled (the image_transform
# line is commented out); mixup and label smoothing come from `flags`.
train_dataset = VinbigdataTwoClassDataset(
        train_dicts,
        #image_transform=Transform(flags.aug_kwargs),
        mixup_prob=flags.mixup_prob,
        label_smoothing=flags.label_smoothing,
    )

train_loader = DataLoader(
        train_dataset,
        batch_size=flags.batchsize,
        num_workers=flags.num_workers,
        shuffle=True,
        pin_memory=True,
    )

# Validation dataset: default train=True so targets are returned, with no
# augmentation, mixup or label smoothing.
valid_dataset = VinbigdataTwoClassDataset(val_dicts)


# Fixed order (shuffle=False) for validation.
valid_loader = DataLoader(
        valid_dataset,
        batch_size=flags.valid_batchsize,
        num_workers=flags.num_workers,
        shuffle=False,
        pin_memory=True,
    )



# Training the model

In [None]:
'''
model_training(model,train_loader,valid_loader,optimizer,lr_sched,
                   #loss_function,
                   cross_entropy_with_logits,
                   accuracy_with_logits,
                   5
                   )
'''

# Load the trained model

In [None]:
# Load the trained weights. map_location="cpu" lets a checkpoint saved from a
# GPU be opened on a CPU-only machine; `load_state_dict` then copies the
# values into the model's parameters on whatever device they live on.
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
checkpoint = torch.load(
    '../input/vinbigdata-2-class-classifier/vinbigdata_checkpoint.pt',
    map_location="cpu",
)
model.load_state_dict(checkpoint['state_dict'])

# Model evaluation using F1 Score 

In [None]:
from sklearn import metrics

# Evaluate the loaded model once on the validation split.
epoch_loss,epoch_acc,model_outputs,labels=validation_epoch(model,valid_loader,cross_entropy_with_logits,
               accuracy_with_logits)
# The epoch number is hard-coded to 0 because this is a single evaluation pass.
print(f"Epoch={0} , Loss={epoch_loss} ,Accuracy={epoch_acc}")
# The matrix is transposed before printing.
# NOTE(review): presumably so rows are predicted classes and columns true classes - confirm.
print(metrics.confusion_matrix(labels, model_outputs).T)
print(metrics.classification_report(labels, model_outputs, digits=3))

In [None]:
# Test-set metadata (image id and original size) comes from a separate dataset.
test_meta = pd.read_csv(inputdir / "vinbigdata-testmeta" / "test_meta.csv")
dataset_dicts_test = get_vinbigdata_dicts_test(imgdir, test_meta, debug=False)
# train=False: the dataset yields images only, without targets.
test_dataset = VinbigdataTwoClassDataset(dataset_dicts_test, train=False)
# shuffle=False keeps predictions in the same order as `dataset_dicts_test`.
test_loader = DataLoader(
    test_dataset,
    batch_size=flags.valid_batchsize,
    num_workers=flags.num_workers,
    shuffle=False,
    pin_memory=True,
)

In [None]:
# Softmax class probabilities, one row per test image in `test_loader` order.
test_predicts=predict_proba(model,test_loader)

In [None]:
# Detection submission to post-process (per the notebook intro, the output of the upstream detectron2 kernel).
sub=pd.read_csv('../input/vinbigdata/submission (11).csv')


In [None]:
# Arg-max class is 0 (normal) or 1 (abnormal), so the sum counts the abnormal predictions.
print(f"number of anomaly Xray images = {torch.argmax(test_predicts,-1).sum()}")

In [None]:
test_predicts

In [None]:
# Predicted class per test image (0 = normal, 1 = abnormal), in the order of
# `dataset_dicts_test`. `.cpu()` makes this work when inference ran on a GPU.
pred_classes = torch.argmax(test_predicts, -1).cpu().numpy().tolist()

if "image_id" in sub.columns:
    # Align by image id instead of trusting that the submission file lists
    # the images in the same order as the test metadata.
    class_by_image = {
        record["image_id"]: class_
        for record, class_ in zip(dataset_dicts_test, pred_classes)
    }
    sub_classes = [class_by_image[image_id] for image_id in sub["image_id"]]
else:
    # No id column to align on: fall back to positional pairing, but refuse
    # to silently truncate when the lengths disagree.
    if len(pred_classes) != len(sub):
        raise ValueError(
            f"{len(pred_classes)} predictions for {len(sub)} submission rows"
        )
    sub_classes = pred_classes

# Images classified as normal get the "No finding" entry (class 14);
# the others keep their detected boxes.
predicts = [
    '14 1 0 0 1 1' if class_ == 0 else bounding_box
    for bounding_box, class_ in zip(sub['PredictionString'], sub_classes)
]
                

In [None]:
len(predicts)

In [None]:
# Replace the detection strings with the classifier-filtered ones.
sub['PredictionString']=predicts

In [None]:
sub

In [None]:
# Write the post-processed submission without the DataFrame index column.
sub.to_csv('sub_02.csv',index=False)

In [None]:
from IPython.display import FileLink
FileLink(r'sub_02.csv')