In [1]:
import pandas as pd
from pathlib import Path
import ast
from PIL import Image, ImageDraw, ImageFont, ImageEnhance
import numpy as np
import copy
import random
import time
import math
import sys


import torch
from torch import nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch import multiprocessing

import torchvision
from torch.utils.data import Dataset, DataLoader, sampler
from torchvision import models, transforms
from torchvision.transforms import functional as F
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
import torchvision.models.detection.mask_rcnn

In [2]:
from collections import defaultdict, deque
import datetime
import pickle
import time

import torch
import torch.distributed as dist

import errno
import os


class SmoothedValue(object):
    """Accumulate scalar observations and expose windowed and global statistics.

    The latest ``window_size`` values are retained for the windowed statistics
    (``median``, ``avg``, ``max``, ``value``); a running ``total`` and ``count``
    back ``global_avg``.
    """

    def __init__(self, window_size=20, fmt=None):
        # Default rendering: windowed median followed by the global average.
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        """Store ``value`` once in the window and weight it ``n`` times in the totals."""
        self.deque.append(value)
        self.total += value * n
        self.count += n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if is_dist_avail_and_initialized():
            # Only the global counters are reduced across processes.
            stats = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
            dist.barrier()
            dist.all_reduce(stats)
            reduced_count, reduced_total = stats.tolist()
            self.count = int(reduced_count)
            self.total = reduced_total

    @property
    def median(self):
        """Median of the values currently in the window."""
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Weighted mean over every update seen so far."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        # Every statistic is offered to the format string; it picks what it needs.
        stats = dict(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
        return self.fmt.format(**stats)


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)

    The object is pickled into a CUDA byte tensor, the per-rank byte counts
    are exchanged first, every rank pads its tensor to the largest count, and
    the padded tensors are then gathered and unpickled.

    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    # Single-process run: nothing to exchange.
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
    # local_size is a one-element tensor, so this comparison yields a one-element tensor.
    if local_size != max_size:
        # Padding bytes are uninitialized; they are sliced off again below.
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        # Keep only the first `size` bytes, i.e. drop the padding added above.
        buffer = tensor.cpu().numpy().tobytes()[:size]
        # NOTE(review): pickle.loads executes whatever the peer rank sent; only safe among trusted processes.
        data_list.append(pickle.loads(buffer))

    return data_list


def reduce_dict(input_dict, average=True):
    """
    Combine the tensor values of ``input_dict`` across all processes.

    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Returns:
        dict: same keys as ``input_dict`` holding the reduced values. With
        fewer than two processes ``input_dict`` itself is returned.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        # Sorted key order keeps the stacked layout identical on every process.
        names = sorted(input_dict.keys())
        stacked = torch.stack([input_dict[name] for name in names], dim=0)
        dist.all_reduce(stacked)
        if average:
            stacked /= world_size
        reduced_dict = dict(zip(names, stacked))
    return reduced_dict


class MetricLogger(object):
    """Collect named ``SmoothedValue`` meters and print periodic progress lines.

    Meters are created on first use and are also reachable as attributes
    (``logger.loss``).
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed one value per keyword into the meter of the same name.

        Tensor values are converted with ``.item()``; anything that is not a
        float or int afterwards trips the assertion.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Read ``meters`` through __dict__: going through ``self.meters`` would
        # re-enter __getattr__ and recurse forever on an instance whose
        # __init__ has not run (e.g. while being copied or unpickled).
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Ask every meter to synchronize its global counters."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register (or replace) the meter stored under ``name``."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` while printing progress.

        A line is printed every ``print_freq`` iterations and on the last
        one, followed by a total-time summary once the iterable is exhausted.
        ``iterable`` must support ``len()``.
        """
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # Pad the iteration counter to the width of the total count.
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent waiting for the next item.
            data_time.update(time.time() - end)
            yield obj
            # Time for the whole iteration, including the caller's work.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # Guard the per-iteration average against an empty iterable.
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(len(iterable), 1)))


def collate_fn(batch):
    """Transpose a batch of samples into one tuple per sample field."""
    per_field = zip(*batch)
    return tuple(per_field)


def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
    """Build a ``LambdaLR`` that ramps the learning-rate multiplier linearly.

    The multiplier starts at ``warmup_factor`` on step 0 and reaches 1 at
    ``warmup_iters``, where it stays.
    """

    def lr_multiplier(step):
        if step < warmup_iters:
            progress = float(step) / warmup_iters
            return warmup_factor * (1 - progress) + progress
        return 1

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_multiplier)


def mkdir(path):
    """Create ``path`` (with parents); an already-existing path is not an error."""
    try:
        os.makedirs(path)
    except OSError as err:
        already_exists = err.errno == errno.EEXIST
        if not already_exists:
            raise


def setup_for_distributed(is_master):
    """
    Replace the builtin ``print`` so that only the master process prints.

    Calls carrying ``force=True`` are printed regardless of ``is_master``;
    the ``force`` keyword is always stripped before delegating.
    """
    import builtins as __builtin__
    original_print = __builtin__.print

    def print(*args, **kwargs):
        forced = kwargs.pop('force', False)
        if forced or is_master:
            original_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is available and initialized."""
    # Short-circuit keeps is_initialized() from being called when unavailable.
    return dist.is_available() and dist.is_initialized()


def get_world_size():
    """Number of distributed processes, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1


def get_rank():
    """Rank of this process, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0


def is_main_process():
    """True when this process has rank 0."""
    rank = get_rank()
    return rank == 0


def save_on_master(*args, **kwargs):
    """Forward the arguments to ``torch.save`` on the main process only."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)


def init_distributed_mode(args):
    """Configure torch.distributed from environment variables and ``args``.

    Reads RANK / WORLD_SIZE / LOCAL_RANK, or SLURM_PROCID as a fallback.
    With neither present it sets ``args.distributed = False`` and returns.
    Otherwise it fills in ``args.rank``, ``args.gpu``, ``args.distributed``
    and ``args.dist_backend``, selects the CUDA device, initializes the
    process group, and silences ``print`` on non-zero ranks.

    ``args.dist_url`` is read but never set here, so the caller must supply it.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
        # NOTE(review): this branch does not set args.world_size, which is
        # read below — the caller presumably provides it; confirm.
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    # Wait for every process before patching print.
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


In [3]:
# Select the 'file_system' sharing strategy for torch.multiprocessing.
# NOTE(review): presumably chosen to avoid file-descriptor exhaustion with
# many DataLoader workers — confirm.
multiprocessing.set_sharing_strategy('file_system')

In [4]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# Root folder of the competition data inside the Kaggle input mount.
DATA_DIR = Path('/kaggle/input') / 'ranzcr-clip-catheter-line-classification'

In [6]:
# Show the CSV files that ship with the dataset.
list(DATA_DIR.glob('*.csv'))

[PosixPath('/kaggle/input/ranzcr-clip-catheter-line-classification/sample_submission.csv'),
 PosixPath('/kaggle/input/ranzcr-clip-catheter-line-classification/train_annotations.csv'),
 PosixPath('/kaggle/input/ranzcr-clip-catheter-line-classification/train.csv')]

In [7]:
# Read the three competition tables in one pass; order matches the targets.
sample_submission, train_annotations, train = (
    pd.read_csv(DATA_DIR / csv_name)
    for csv_name in ('sample_submission.csv', 'train_annotations.csv', 'train.csv')
)

In [8]:
def get_bounding_box(cordinates):
    """Return the axis-aligned box ``[xmin, ymin, xmax, ymax]`` around the points.

    Args:
        cordinates: iterable of ``(x, y)`` pairs.

    Returns:
        list: ``[xmin, ymin, xmax, ymax]``. An empty input yields the
        degenerate box ``[99999, 99999, -1, -1]``.
    """
    points = list(cordinates)
    if not points:
        # Kept for backward compatibility with callers of the old sentinel-based code.
        return [99999, 99999, -1, -1]
    # min/max over the real values: no sentinel can clip coordinates that are
    # larger than 99999 or smaller than -1.
    xs = [x for x, _ in points]
    ys = [y for _, y in points]
    return [min(xs), min(ys), max(xs), max(ys)]

In [9]:
# Load the annotation table into `data`. The same CSV is already held in
# `train_annotations`, and `data` is re-read from disk before its next use.
data = pd.read_csv(DATA_DIR/ 'train_annotations.csv')

In [10]:
def convert_strlist2tuple(mask_string):
    """Parse a stringified list of points into a list of ``(x, y)`` tuples.

    Negative coordinates are replaced by 1; all other values pass through.
    """
    def clamp(coord):
        return 1 if coord < 0 else coord

    return [(clamp(point[0]), clamp(point[1])) for point in ast.literal_eval(mask_string)]

In [11]:
# Annotation table with parsed landmark points and their bounding boxes.
data = pd.read_csv(DATA_DIR / 'train_annotations.csv')
data['mask'] = data['data'].apply(convert_strlist2tuple)
data['bounding_box'] = data['mask'].apply(get_bounding_box)

In [12]:
def draw_landmarks(StudyInstanceUID):
    """Render the annotated catheters of one study on its image.

    Each annotation row is drawn on the image as a polyline, a text label at
    the polyline's middle point and a bounding rectangle, all in one random
    colour; the polyline is also drawn on a separate single-channel mask.

    Args:
        StudyInstanceUID: id of the study; also the stem of the JPEG file
            under ``DATA_DIR / 'train'``.

    Returns:
        tuple: ``(image, mask)`` PIL images resized to a width of 512 pixels
        with the aspect ratio preserved.
    """
    _sample = data[data.StudyInstanceUID == StudyInstanceUID]
    if _sample.shape[0] == 0:
        # No annotation rows: the plain image and an empty mask are still returned.
        print("Landmarks absent")
    image_file = f'{StudyInstanceUID}.jpg'
    image = Image.open(DATA_DIR / 'train'/ image_file).convert("RGB")
    mask = Image.new('L', image.size, color=0)
    font = ImageFont.truetype('/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf', size=80)
    draw_img = ImageDraw.Draw(image)
    draw_mask = ImageDraw.Draw(mask)
    for index, row in _sample.iterrows():
        r, g, b = np.random.randint(0, 255, 3)
        landmarks = row['mask']
        bounding_box = row.bounding_box
        draw_img.line(landmarks, fill=(r, g, b), width=5)
        draw_mask.line(landmarks, fill=int(r), width=5)
        # Anchor the label at the middle landmark of the polyline.
        draw_img.text(landmarks[int((len(landmarks)/2))], row.label, font=font, fill=(r, g, b))
        draw_img.rectangle(bounding_box, outline=(r, g, b), width=15)

    basewidth = 512
    wpercent = (basewidth/float(image.size[0]))
    hsize = int((float(image.size[1])*float(wpercent)))
    # Image.ANTIALIAS was an alias of LANCZOS and was removed in Pillow 10.
    image = image.resize((basewidth, hsize), Image.LANCZOS)
    mask = mask.resize((basewidth, hsize), Image.LANCZOS)
    return image, mask

In [13]:
# Render the study at position 10 of the training table as a visual check.
sample_uid = train.StudyInstanceUID.values[10]
img, mask = draw_landmarks(sample_uid)

In [14]:
# One row per study, with its labels, boxes and landmark lists collected.
_data = (
    data
    .groupby(['StudyInstanceUID'])
    .agg({'label': list, 'bounding_box': list, 'mask': list})
    .reset_index()
)

In [15]:
# Fine-grained annotation label -> catheter family.
column_label_map = {
    'ETT - Abnormal': 'ETT',
    'ETT - Borderline': 'ETT',
    'ETT - Normal': 'ETT',
    'NGT - Abnormal': 'NGT',
    'NGT - Borderline': 'NGT',
    'NGT - Normal': 'NGT',
    'NGT - Incompletely Imaged': 'NGT',
    'CVC - Abnormal': 'CVC',
    'CVC - Borderline': 'CVC',
    'CVC - Normal': 'CVC',
    'Swan Ganz Catheter Present': 'SGC',
}

# Catheter family -> its fine-grained labels, derived from the forward map
# so both stay in sync; families and labels keep their insertion order.
reverse_column_label_map = {
    family: [name for name, owner in column_label_map.items() if owner == family]
    for family in dict.fromkeys(column_label_map.values())
}

In [16]:
# Number of annotations per fine-grained label, tagged with its catheter family.
type_counts = data.label.value_counts()
label_count = pd.DataFrame({'Type': type_counts.index, 'Count': type_counts.values})
label_count['Label'] = label_count.Type.apply(lambda type_name: column_label_map[type_name])

In [17]:
# Share of each fine-grained label, in percent.
type_share = data.label.value_counts(normalize=True).mul(100)
label_freq = pd.DataFrame({'Type': type_share.index, 'Percentage': type_share.values})
label_freq['Label'] = label_freq.Type.apply(lambda type_name: column_label_map[type_name])

# Join counts with percentages; the duplicated 'Label' column from the right
# side gets the '_DROP' suffix and is filtered out again.
label_data = (
    label_count
    .merge(label_freq, how='inner', on='Type', suffixes=('', '_DROP'))
    .filter(regex='^(?!.*_DROP)')
    .sort_values('Label')
)

In [18]:
class RandomHorizontalFlip(object):
    """Mirror image, mask and boxes left-to-right with probability ``prob``."""

    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, mask, target):
        # Guard clause: most of the time nothing is flipped.
        if random.random() >= self.prob:
            return image, mask, target

        _, width = image.shape[-2:]
        flipped_image = image.flip(-1)
        flipped_mask = mask.flip(-1)
        boxes = target["boxes"]
        # New xmin comes from the old xmax and vice versa, mirrored about the width.
        boxes[:, [0, 2]] = width - boxes[:, [2, 0]]
        target["boxes"] = boxes
        return flipped_image, flipped_mask, target

In [19]:
class Rescale(object):
    """Resize a PIL image (and optional mask) and scale the boxes to match.

    Args:
        output_size (int or tuple): a tuple is used as ``(height, width)``;
            an int is the length of the shorter edge, with the aspect ratio
            preserved.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, image, mask, target):
        # PIL reports size as (width, height); unpacking it as (h, w) scaled
        # x by the height ratio and y by the width ratio on non-square images.
        w, h = image.size[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)
        # transforms.Resize expects (height, width).
        img = transforms.Resize((new_h, new_w))(image)
        if mask is not None:
            mask = transforms.Resize((new_h, new_w))(mask)

        x_scale = new_w / w
        y_scale = new_h / h
        for index in range(len(target['boxes'])):
            xmin = int(target['boxes'][index][0] * x_scale)
            ymin = int(target['boxes'][index][1] * y_scale)
            xmax = int(target['boxes'][index][2] * x_scale)
            ymax = int(target['boxes'][index][3] * y_scale)
            # Coordinates truncated to 0 are bumped to 1, as before.
            target['boxes'][index][0] = xmin if xmin != 0 else 1
            target['boxes'][index][1] = ymin if ymin != 0 else 1
            target['boxes'][index][2] = xmax if xmax != 0 else 1
            target['boxes'][index][3] = ymax if ymax != 0 else 1
        return img, mask, target

In [20]:
class ToTensor(object):
    """Convert image and mask with ``transforms.ToTensor`` and cast target entries.

    ``boxes`` become float32 and ``labels`` int64; any other target key
    raises ``KeyError``.
    """

    def __call__(self, image, mask, target):
        to_tensor = transforms.ToTensor()
        image = to_tensor(image)
        mask = to_tensor(mask)
        dtype_for = {'boxes': torch.float32, 'labels': torch.int64}
        for key in target.keys():
            target[key] = torch.as_tensor(target[key], dtype=dtype_for[key])
        return image, mask, target

In [21]:
class Compose(object):
    """Chain transforms that each map ``(image, mask, target)`` to the same triple."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, mask, target):
        state = (image, mask, target)
        for step in self.transforms:
            # Unpacking enforces that every step returns exactly three items.
            image, mask, target = step(*state)
            state = (image, mask, target)
        return state

In [22]:
def get_transform(train):
    """Build the preprocessing pipeline: resize to 512x512, then tensor conversion.

    ``train`` is accepted but currently unused: the random horizontal flip
    (p=0.3) and ImageNet normalization steps are disabled, so training and
    evaluation share the same pipeline.
    """
    pipeline = [Rescale((512, 512)), ToTensor()]
    return Compose(pipeline)

In [23]:
class RanzcrDataSet(Dataset):
    """Detection/segmentation dataset built from ``train_annotations.csv``.

    Each item is one study: its RGB image plus a target dict with ``boxes``,
    ``labels``, ``masks``, ``area``, ``iscrowd`` and ``image_id``.

    Args:
        DATA_DIR (Path): folder holding ``train_annotations.csv`` and the
            image sub-folders.
        partition (str): image sub-folder name under ``DATA_DIR``.
        transforms: optional callable ``(image, mask, target) -> (image, mask, target)``.
    """

    def __init__(self, DATA_DIR, partition, transforms = None):
        super().__init__()
        self.DATA_DIR = DATA_DIR
        self.partition = partition
        # Fine-grained annotation label -> class id (0 is left for background).
        self.column_label_map = {'ETT - Abnormal': 1, 'ETT - Borderline': 1, 'ETT - Normal': 1 ,
                                 'NGT - Abnormal': 2, 'NGT - Borderline': 2, 'NGT - Normal': 2 ,'NGT - Incompletely Imaged':2,
                                 'CVC - Abnormal': 3, 'CVC - Borderline': 3, 'CVC - Normal': 3,
                                 'Swan Ganz Catheter Present':4}

        _data = pd.read_csv(self.DATA_DIR/ 'train_annotations.csv')
        # Landmarks are stored as a stringified list of [x, y] pairs.
        _data['mask'] = _data['data'].apply(lambda x: [tuple(li) for li in ast.literal_eval(x)])
        _data['bounding_box'] = _data['mask'].apply(RanzcrDataSet.get_bounding_box)
        # One row per study, with per-annotation lists.
        self.data = _data.groupby(['StudyInstanceUID']).agg({'label':list, 'bounding_box':list, 'mask':list}).reset_index()
        self.transforms = transforms

    @staticmethod
    def get_bounding_box(cordinates):
        """Return ``[xmin, ymin, xmax, ymax]`` around the ``(x, y)`` points.

        An empty input yields the degenerate box ``[99999, 99999, -1, -1]``.
        """
        points = list(cordinates)
        if not points:
            # Kept for backward compatibility with the old sentinel-based code.
            return [99999, 99999, -1, -1]
        # min/max over the real values: the landmarks are not clamped in this
        # class, so sentinel start values could clip negative or huge coordinates.
        xs = [x for x, _ in points]
        ys = [y for _, y in points]
        return [min(xs), min(ys), max(xs), max(ys)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_file = f'{self.data.StudyInstanceUID[idx]}.jpg'
        # Release the file handle as soon as the pixels are decoded.
        with Image.open(self.DATA_DIR / self.partition / img_file) as raw:
            img = raw.convert('RGB')
        mask = Image.new('L', img.size, color=0)
        draw_mask = ImageDraw.Draw(mask)
        for landmarks in self.data['mask'][idx]:
            draw_mask.line(landmarks, fill=255, width=8)

        boxes = torch.as_tensor(self.data['bounding_box'][idx], dtype=torch.float32)
        labels = torch.as_tensor(
            [self.column_label_map[lb] for lb in self.data['label'][idx]],
            dtype=torch.int64)
        image_id = torch.tensor([idx])

        target = {"boxes": boxes, "labels": labels}

        if self.transforms:
            img, mask, target = self.transforms(img, mask, target)

        # NOTE(review): every catheter is drawn with the same fill value, so
        # splitting the mask by distinct gray levels does not give one mask
        # per box; Mask R-CNN expects masks aligned with boxes. Left as-is
        # because a fix needs per-instance masks through the transforms.
        mask = np.array(mask)
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[1:]
        masks = mask == obj_ids[:, None, None]
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        # Use the boxes as returned by the transforms, not a pre-transform alias.
        boxes = torch.as_tensor(target["boxes"], dtype=torch.float32)
        area = torch.abs((boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]))

        # One crowd flag per annotated box (previously sized by gray-level count).
        iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)

        target["masks"] = masks
        target["area"] = area
        target["iscrowd"] = iscrowd
        target["image_id"] = image_id

        return img, target


In [24]:
def get_model_instance_segmentation(num_classes):
    """Return a pretrained Mask R-CNN (ResNet-50 FPN) with fresh prediction heads.

    Both the box classifier and the mask predictor are replaced so that they
    output ``num_classes`` classes.
    """
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    heads = model.roi_heads

    # New box head, sized from the existing classifier's input features.
    box_in_features = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    # New mask head with a 256-channel hidden layer.
    mask_in_channels = heads.mask_predictor.conv5_mask.in_channels
    heads.mask_predictor = MaskRCNNPredictor(mask_in_channels, 256, num_classes)

    return model

In [25]:
def train_one_epoch(model, optimizer, lr_scheduler, data_loader, device, epoch, print_freq):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        model: detection model that returns a dict of losses in train mode.
        optimizer: optimizer stepped once per successful batch.
        lr_scheduler: optional scheduler stepped after every optimizer step;
            pass ``None`` to keep the learning rate fixed. No warm-up
            schedule is created here.
        data_loader: yields ``(images, targets)`` batches.
        device: device the images and targets are moved to.
        epoch: epoch index, used only in the log header.
        print_freq: number of iterations between progress lines.

    Returns:
        MetricLogger: the logger holding the loss and lr meters.

    Batches whose forward pass raises are skipped; the first few failures
    are printed and the number skipped is reported at the end. A non-finite
    loss stops the process via ``sys.exit(1)``.
    """
    model.train()
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    skipped_batches = 0
    max_reported_errors = 5  # keep the log readable if many batches fail

    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        try:
            loss_dict = model(images, targets)
        except Exception as e:
            # Still best-effort, but no longer silent: report what went wrong.
            skipped_batches += 1
            if skipped_batches <= max_reported_errors:
                print("Skipping batch after forward-pass error: {!r}".format(e))
            continue
        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        loss_value = losses_reduced.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    if skipped_batches:
        print("{} Skipped {} batch(es) because the forward pass raised".format(
            header, skipped_batches))

    return metric_logger

In [26]:
# Two views of the same 'train' images: one with the training transform flag,
# one with the evaluation flag.
dataset, dataset_test = (
    RanzcrDataSet(DATA_DIR, 'train', get_transform(is_train))
    for is_train in (True, False)
)

In [27]:
def collate_fn(batch):
    """Transpose a batch of samples into one tuple per sample field."""
    per_field = zip(*batch)
    return tuple(per_field)

In [29]:
# Seed the shuffle so the train/test membership is identical on every run.
SPLIT_SEED = 0
split_rng = torch.Generator().manual_seed(SPLIT_SEED)
indices = torch.randperm(len(dataset), generator=split_rng).tolist()

# Hold out the last 50 shuffled studies for testing; train on the rest.
dataset = torch.utils.data.Subset(dataset, indices[:-50])
dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:])

data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=1, shuffle=False, num_workers=12,
    collate_fn=collate_fn)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, shuffle=False, num_workers=12,
    collate_fn=collate_fn)

In [30]:
# Five outputs: the four catheter families plus background.
model = get_model_instance_segmentation(5)

# Optimize only the parameters that require gradients.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.002,
    momentum=0.9,
    weight_decay=0.0005,
)
# Multiply the learning rate by 0.1 every 1000 scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=1000,
    gamma=0.1,
)

Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /root/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth


  0%|          | 0.00/170M [00:00<?, ?B/s]

In [31]:
# Move the model onto `device`; as the last expression of the cell, the
# returned module's repr is displayed.
model.to(device)

MaskRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256)
          (relu): ReLU(inplace=True)
          (downsample): 

In [None]:
num_epochs = 10
for epoch in range(num_epochs):
    # One full pass over the training loader with a progress line every
    # 1000 iterations. No scheduler is passed (None), so the learning rate
    # stays fixed for the whole run.
    train_one_epoch(
        model,
        optimizer,
        None,
        data_loader,
        device,
        epoch,
        print_freq=1000,
    )



Epoch: [0]  [   0/8995]  eta: 22:41:24  lr: 0.002000  loss: 5.8305 (5.8305)  loss_classifier: 1.4100 (1.4100)  loss_box_reg: 0.1768 (0.1768)  loss_mask: 3.8874 (3.8874)  loss_objectness: 0.3230 (0.3230)  loss_rpn_box_reg: 0.0333 (0.0333)  time: 9.0811  data: 5.7715  max mem: 1567
Epoch: [0]  [1000/8995]  eta: 0:45:27  lr: 0.002000  loss: 0.2794 (0.4487)  loss_classifier: 0.0799 (0.1030)  loss_box_reg: 0.0837 (0.0912)  loss_mask: 0.0075 (0.0187)  loss_objectness: 0.0686 (0.1374)  loss_rpn_box_reg: 0.0113 (0.0985)  time: 0.3217  data: 0.0332  max mem: 1993
Epoch: [0]  [2000/8995]  eta: 0:39:16  lr: 0.002000  loss: 0.3115 (0.4001)  loss_classifier: 0.1083 (0.0979)  loss_box_reg: 0.1277 (0.0932)  loss_mask: 0.0077 (0.0133)  loss_objectness: 0.0435 (0.1093)  loss_rpn_box_reg: 0.0208 (0.0864)  time: 0.3111  data: 0.0305  max mem: 1996
Epoch: [0]  [3000/8995]  eta: 0:33:29  lr: 0.002000  loss: 0.2245 (0.3854)  loss_classifier: 0.0859 (0.0984)  loss_box_reg: 0.0806 (0.0962)  loss_mask: 0.0088 