In [1]:
import numpy as np
import pandas as pd

In [2]:
!pip install pycocotools

Collecting pycocotools
  Downloading pycocotools-2.0.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.1 kB)
Downloading pycocotools-2.0.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (426 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m426.2/426.2 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[?25hInstalling collected packages: pycocotools
Successfully installed pycocotools-2.0.7


In [3]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.models.detection import RetinaNet
from torchvision.datasets import CocoDetection
from pycocotools.coco import COCO
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os


In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = CocoDetection(root='/kaggle/input/coco-2017-dataset/coco2017/train2017', annFile='/kaggle/input/coco-2017-dataset/coco2017/annotations/instances_train2017.json', transform=transform)
val_dataset = CocoDetection(root='/kaggle/input/coco-2017-dataset/coco2017/val2017', annFile='/kaggle/input/coco-2017-dataset/coco2017/annotations/instances_val2017.json', transform=transform)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4, collate_fn=lambda x: tuple(zip(*x)), drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4, collate_fn=lambda x: tuple(zip(*x)), drop_last=True)

loading annotations into memory...
Done (t=24.65s)
creating index...
index created!
loading annotations into memory...
Done (t=0.98s)
creating index...
index created!


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        targets = torch.tensor(targets)
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        
        if self.reduction == 'mean':
            return F_loss.mean()
        elif self.reduction == 'sum':
            return F_loss.sum()
        else:
            return F_loss

focal_loss = FocalLoss()


In [6]:
class Backbone(nn.Module):
    def __init__(self, model_name='resnet50'):
        super(Backbone, self).__init__()
        if model_name == 'resnet50':
            resnet = models.resnet50(pretrained=True)
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

    def forward(self, x):
        return self.backbone(x)


In [7]:
class FeaturePyramidNetwork(nn.Module):
    def __init__(self, C3_size, C4_size, C5_size, feature_size=256):
        super(FeaturePyramidNetwork, self).__init__()

        self.P5_1 = nn.Conv2d(C5_size, feature_size, kernel_size=1, stride=1, padding=0)
        self.P5_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)

        self.P4_1 = nn.Conv2d(C4_size, feature_size, kernel_size=1, stride=1, padding=0)
        self.P4_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)

        self.P3_1 = nn.Conv2d(C3_size, feature_size, kernel_size=1, stride=1, padding=0)
        self.P3_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)

        self.P6 = nn.Conv2d(C5_size, feature_size, kernel_size=3, stride=2, padding=1)
        self.P7 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=2, padding=1)

    def forward(self, C3, C4, C5):
        P5_x = self.P5_1(C5)
        P5_upsampled = F.interpolate(P5_x, scale_factor=2, mode="nearest")
        P5 = self.P5_2(P5_x)

        P4_x = self.P4_1(C4)
        P4_x = P4_x + P5_upsampled
        P4_upsampled = F.interpolate(P4_x, scale_factor=2, mode="nearest")
        P4 = self.P4_2(P4_x)

        P3_x = self.P3_1(C3)
        P3_x = P3_x + P4_upsampled
        P3 = self.P3_2(P3_x)

        P6 = self.P6(C5)
        P7 = self.P7(F.relu(P6))

        return [P3, P4, P5, P6, P7]


In [8]:
class SubNet(nn.Module):
    def __init__(self, num_classes, num_anchors):
        super(SubNet, self).__init__()
        self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(256, num_anchors * num_classes, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))
        return self.conv5(x)

class RetinaNet(nn.Module):
    def __init__(self, num_classes, backbone):
        super(RetinaNet, self).__init__()
        self.backbone = backbone
        fpn_sizes = [256, 512, 1024, 2048]  # Sizes of C3, C4, C5
        self.fpn = FeaturePyramidNetwork(fpn_sizes[0], fpn_sizes[1], fpn_sizes[2])
        self.num_anchors = 9
        self.num_classes = num_classes
        self.classification_head = SubNet(num_classes, self.num_anchors)
        self.regression_head = SubNet(4, self.num_anchors)

    def forward(self, x):
        C3, C4, C5 = self.backbone(x)
        features = self.fpn(C3, C4, C5)
        classifications = [self.classification_head(f) for f in features]
        regressions = [self.regression_head(f) for f in features]
        return classifications, regressions


In [9]:
model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True)
num_classes = 91  # COCO has 80 classes + background

model.head.classification_head.num_classes = num_classes

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)


Downloading: "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth" to /root/.cache/torch/hub/checkpoints/retinanet_resnet50_fpn_coco-eeacb38b.pth
100%|██████████| 130M/130M [00:00<00:00, 147MB/s]  


RetinaNet(
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): FrozenBatchNorm2d(256, eps=0.0)


In [10]:
from tqdm import tqdm

def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10):
    model.train()
    running_loss = 0.0
    data_loader = tqdm(data_loader, desc=f"Epoch {epoch}", unit="batch")
    
    for i, (images, targets) in enumerate(data_loader):
        images = list(image.to(device) for image in images)
        
        valid_targets = []
        for t in targets:
            if len(t) == 0:
                valid_targets.append({'boxes': torch.zeros((0, 4), dtype=torch.float32).to(device),
                                      'labels': torch.zeros((0,), dtype=torch.int64).to(device)})
                continue
            
            boxes = torch.tensor([obj['bbox'] for obj in t], dtype=torch.float32)
            labels = torch.tensor([obj['category_id'] for obj in t], dtype=torch.int64)
            
            if boxes.ndim == 1:
                boxes = boxes.unsqueeze(0)
            
            boxes[:, 2:] += boxes[:, :2]
            
            keep = (boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1])
            valid_targets.append({
                'boxes': boxes[keep].to(device),
                'labels': labels[keep].to(device)
            })

        loss_dict = model(images, valid_targets)


        classification_loss = loss_dict.get('classification', torch.tensor(0.0, device=device, requires_grad=True))
        bbox_regression_loss = loss_dict.get('bbox_regression', torch.tensor(0.0, device=device, requires_grad=True))


        loss = classification_loss + bbox_regression_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        data_loader.set_postfix(loss=running_loss / (i + 1))

        if i % print_freq == 0:
            print(f"Epoch: {epoch}, Iteration: {i}, Loss: {loss.item()}")
def evaluate(model, data_loader, device):
    model.eval()
    running_loss = 0.0
    data_loader = tqdm(data_loader, desc="Evaluating", unit="batch")
    
    with torch.no_grad():
        for i, (images, targets) in enumerate(data_loader):
            images = list(image.to(device) for image in images)
            valid_targets = []
            for t in targets:
                if len(t) == 0:
                    valid_targets.append({'boxes': torch.zeros((0, 4), dtype=torch.float32).to(device),
                                          'labels': torch.zeros((0,), dtype=torch.int64).to(device)})
                    continue
                
                boxes = torch.tensor([obj['bbox'] for obj in t], dtype=torch.float32)
                labels = torch.tensor([obj['category_id'] for obj in t], dtype=torch.int64)
                
                if boxes.ndim == 1:
                    boxes = boxes.unsqueeze(0)
                
                boxes[:, 2:] += boxes[:, :2]
                

                keep = (boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1])
                valid_targets.append({
                    'boxes': boxes[keep].to(device),
                    'labels': labels[keep].to(device)
                })
            
            outputs = model(images)

            running_loss += sum(loss.item() for loss in outputs.values())
            data_loader.set_postfix(loss=running_loss / (i + 1))
            print(outputs)


In [11]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)


In [12]:
num_epochs = 10

for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, train_loader, device, epoch, print_freq=50)
    evaluate(model, val_loader, device)


Epoch 0:   0%|          | 1/29571 [00:03<25:03:38,  3.05s/batch, loss=2.56]

Epoch: 0, Iteration: 0, Loss: 2.5584449768066406


Epoch 0:   0%|          | 51/29571 [00:43<6:09:29,  1.33batch/s, loss=1.78]

Epoch: 0, Iteration: 50, Loss: 1.3251490592956543


Epoch 0:   0%|          | 101/29571 [01:24<6:37:48,  1.23batch/s, loss=1.46]

Epoch: 0, Iteration: 100, Loss: 0.8877701163291931


Epoch 0:   1%|          | 151/29571 [02:04<6:36:06,  1.24batch/s, loss=1.3] 

Epoch: 0, Iteration: 150, Loss: 0.7371338605880737


Epoch 0:   1%|          | 201/29571 [02:45<7:23:23,  1.10batch/s, loss=1.25]

Epoch: 0, Iteration: 200, Loss: 1.7367942333221436


Epoch 0:   1%|          | 251/29571 [03:25<7:03:07,  1.15batch/s, loss=1.24]

Epoch: 0, Iteration: 250, Loss: 1.0984268188476562


Epoch 0:   1%|          | 301/29571 [04:04<6:39:46,  1.22batch/s, loss=1.2] 

Epoch: 0, Iteration: 300, Loss: 0.8813862800598145


Epoch 0:   1%|          | 351/29571 [04:44<6:43:25,  1.21batch/s, loss=1.21]

Epoch: 0, Iteration: 350, Loss: 4.4558868408203125


Epoch 0:   1%|▏         | 401/29571 [05:24<6:54:31,  1.17batch/s, loss=1.23]

Epoch: 0, Iteration: 400, Loss: 1.0085225105285645


Epoch 0:   2%|▏         | 451/29571 [06:05<6:55:19,  1.17batch/s, loss=1.24]

Epoch: 0, Iteration: 450, Loss: 0.976171612739563


Epoch 0:   2%|▏         | 501/29571 [06:44<6:20:54,  1.27batch/s, loss=1.24]

Epoch: 0, Iteration: 500, Loss: 1.2045652866363525


Epoch 0:   2%|▏         | 551/29571 [07:25<6:38:44,  1.21batch/s, loss=1.22]

Epoch: 0, Iteration: 550, Loss: 0.938727617263794


Epoch 0:   2%|▏         | 601/29571 [08:04<6:05:54,  1.32batch/s, loss=1.21]

Epoch: 0, Iteration: 600, Loss: 1.3535099029541016


Epoch 0:   2%|▏         | 651/29571 [08:45<6:29:02,  1.24batch/s, loss=1.2] 

Epoch: 0, Iteration: 650, Loss: 0.9389598965644836


Epoch 0:   2%|▏         | 701/29571 [09:26<6:19:49,  1.27batch/s, loss=1.2] 

Epoch: 0, Iteration: 700, Loss: 1.2568880319595337


Epoch 0:   3%|▎         | 751/29571 [10:06<6:32:40,  1.22batch/s, loss=1.18]

Epoch: 0, Iteration: 750, Loss: 0.9637616872787476


Epoch 0:   3%|▎         | 801/29571 [10:47<6:28:41,  1.23batch/s, loss=1.17]

Epoch: 0, Iteration: 800, Loss: 0.9904544353485107


Epoch 0:   3%|▎         | 851/29571 [11:28<6:52:30,  1.16batch/s, loss=1.15]

Epoch: 0, Iteration: 850, Loss: 0.6422180533409119


Epoch 0:   3%|▎         | 901/29571 [12:07<6:37:34,  1.20batch/s, loss=1.14]

Epoch: 0, Iteration: 900, Loss: 0.84706711769104


Epoch 0:   3%|▎         | 951/29571 [12:48<6:05:59,  1.30batch/s, loss=1.12]

Epoch: 0, Iteration: 950, Loss: 0.7746067047119141


Epoch 0:   3%|▎         | 1001/29571 [13:27<5:22:03,  1.48batch/s, loss=1.11]

Epoch: 0, Iteration: 1000, Loss: 0.9958347082138062


Epoch 0:   4%|▎         | 1051/29571 [14:08<5:47:32,  1.37batch/s, loss=1.11]

Epoch: 0, Iteration: 1050, Loss: 0.7329617142677307


Epoch 0:   4%|▎         | 1101/29571 [14:49<6:21:51,  1.24batch/s, loss=1.09]

Epoch: 0, Iteration: 1100, Loss: 0.9973005056381226


Epoch 0:   4%|▍         | 1151/29571 [15:30<5:50:27,  1.35batch/s, loss=1.08]

Epoch: 0, Iteration: 1150, Loss: 0.6038447022438049


Epoch 0:   4%|▍         | 1201/29571 [16:09<5:56:42,  1.33batch/s, loss=1.08]

Epoch: 0, Iteration: 1200, Loss: 0.6395095586776733


Epoch 0:   4%|▍         | 1251/29571 [16:52<6:09:00,  1.28batch/s, loss=1.07]

Epoch: 0, Iteration: 1250, Loss: 0.6234492659568787


Epoch 0:   4%|▍         | 1301/29571 [17:32<6:11:08,  1.27batch/s, loss=1.07]

Epoch: 0, Iteration: 1300, Loss: 0.8990615606307983


Epoch 0:   5%|▍         | 1351/29571 [18:14<6:49:47,  1.15batch/s, loss=1.06]

Epoch: 0, Iteration: 1350, Loss: 0.6774499416351318


Epoch 0:   5%|▍         | 1401/29571 [18:55<6:26:56,  1.21batch/s, loss=1.06]

Epoch: 0, Iteration: 1400, Loss: 1.3947951793670654


Epoch 0:   5%|▍         | 1451/29571 [19:36<6:18:07,  1.24batch/s, loss=1.06]

Epoch: 0, Iteration: 1450, Loss: 1.0327324867248535


Epoch 0:   5%|▌         | 1501/29571 [20:15<6:19:32,  1.23batch/s, loss=1.07]

Epoch: 0, Iteration: 1500, Loss: 1.1637322902679443


Epoch 0:   5%|▌         | 1551/29571 [20:55<5:15:38,  1.48batch/s, loss=1.07]

Epoch: 0, Iteration: 1550, Loss: 1.3578410148620605


Epoch 0:   5%|▌         | 1601/29571 [21:35<6:18:56,  1.23batch/s, loss=1.07]

Epoch: 0, Iteration: 1600, Loss: 1.2591513395309448


Epoch 0:   6%|▌         | 1651/29571 [22:14<5:42:31,  1.36batch/s, loss=1.07]

Epoch: 0, Iteration: 1650, Loss: 0.7209855318069458


Epoch 0:   6%|▌         | 1701/29571 [22:55<5:59:38,  1.29batch/s, loss=1.06]

Epoch: 0, Iteration: 1700, Loss: 1.3535374402999878


Epoch 0:   6%|▌         | 1751/29571 [23:35<6:15:55,  1.23batch/s, loss=1.06]

Epoch: 0, Iteration: 1750, Loss: 0.8304822444915771


Epoch 0:   6%|▌         | 1801/29571 [24:17<5:35:47,  1.38batch/s, loss=1.06]

Epoch: 0, Iteration: 1800, Loss: 0.6131355166435242


Epoch 0:   6%|▋         | 1851/29571 [24:58<6:04:16,  1.27batch/s, loss=1.06]

Epoch: 0, Iteration: 1850, Loss: 1.0155150890350342


Epoch 0:   6%|▋         | 1901/29571 [25:38<5:29:32,  1.40batch/s, loss=1.05]

Epoch: 0, Iteration: 1900, Loss: 0.9218957424163818


Epoch 0:   7%|▋         | 1951/29571 [26:19<6:22:54,  1.20batch/s, loss=1.05]

Epoch: 0, Iteration: 1950, Loss: 0.5430458784103394


Epoch 0:   7%|▋         | 2001/29571 [27:00<6:01:20,  1.27batch/s, loss=1.04]

Epoch: 0, Iteration: 2000, Loss: 0.7996156811714172


Epoch 0:   7%|▋         | 2051/29571 [27:42<6:16:40,  1.22batch/s, loss=1.04]

Epoch: 0, Iteration: 2050, Loss: 0.7476477026939392


Epoch 0:   7%|▋         | 2151/29571 [29:04<6:37:35,  1.15batch/s, loss=1.04]

Epoch: 0, Iteration: 2150, Loss: 0.7688581347465515


Epoch 0:   7%|▋         | 2201/29571 [29:45<6:04:21,  1.25batch/s, loss=1.03]

Epoch: 0, Iteration: 2200, Loss: 0.9622167348861694


Epoch 0:   8%|▊         | 2251/29571 [30:26<5:55:37,  1.28batch/s, loss=1.03]

Epoch: 0, Iteration: 2250, Loss: 0.7215970754623413


Epoch 0:   8%|▊         | 2301/29571 [31:06<6:35:50,  1.15batch/s, loss=1.02]

Epoch: 0, Iteration: 2300, Loss: 0.827655553817749


Epoch 0:   8%|▊         | 2351/29571 [31:46<6:06:34,  1.24batch/s, loss=1.02]

Epoch: 0, Iteration: 2350, Loss: 1.04054856300354


Epoch 0:   8%|▊         | 2401/29571 [32:26<6:22:47,  1.18batch/s, loss=1.02]

Epoch: 0, Iteration: 2400, Loss: 1.562365174293518


Epoch 0:   8%|▊         | 2451/29571 [33:07<6:10:55,  1.22batch/s, loss=1.02]

Epoch: 0, Iteration: 2450, Loss: 0.9429277181625366


Epoch 0:   8%|▊         | 2501/29571 [33:48<5:39:09,  1.33batch/s, loss=1.02]

Epoch: 0, Iteration: 2500, Loss: 1.0264711380004883


Epoch 0:   9%|▊         | 2551/29571 [34:29<6:33:52,  1.14batch/s, loss=1.02]

Epoch: 0, Iteration: 2550, Loss: 0.7677050232887268


Epoch 0:   9%|▉         | 2601/29571 [35:09<6:01:30,  1.24batch/s, loss=1.01]

Epoch: 0, Iteration: 2600, Loss: 0.8180989027023315


Epoch 0:   9%|▉         | 2651/29571 [35:47<6:12:17,  1.21batch/s, loss=1.01]

Epoch: 0, Iteration: 2650, Loss: 0.7150247097015381


Epoch 0:   9%|▉         | 2701/29571 [36:28<6:12:54,  1.20batch/s, loss=1.01]

Epoch: 0, Iteration: 2700, Loss: 1.002429723739624


Epoch 0:   9%|▉         | 2751/29571 [37:07<5:10:00,  1.44batch/s, loss=1.01]

Epoch: 0, Iteration: 2750, Loss: 0.9159128069877625


Epoch 0:   9%|▉         | 2801/29571 [37:48<5:31:38,  1.35batch/s, loss=1.01]

Epoch: 0, Iteration: 2800, Loss: 0.6356573104858398


Epoch 0:  10%|▉         | 2851/29571 [38:28<5:41:36,  1.30batch/s, loss=1.01]

Epoch: 0, Iteration: 2850, Loss: 1.2588895559310913


Epoch 0:  10%|▉         | 2901/29571 [39:08<6:17:19,  1.18batch/s, loss=1.01]

Epoch: 0, Iteration: 2900, Loss: 0.7353414297103882


Epoch 0:  10%|▉         | 2951/29571 [39:49<6:01:45,  1.23batch/s, loss=1.01]

Epoch: 0, Iteration: 2950, Loss: 0.7811344861984253


Epoch 0:  10%|█         | 3001/29571 [40:30<6:27:04,  1.14batch/s, loss=1]   

Epoch: 0, Iteration: 3000, Loss: 0.997367799282074


Epoch 0:  10%|█         | 3051/29571 [41:12<6:18:19,  1.17batch/s, loss=1]

Epoch: 0, Iteration: 3050, Loss: 0.7586492896080017


Epoch 0:  10%|█         | 3101/29571 [41:53<6:08:30,  1.20batch/s, loss=1]

Epoch: 0, Iteration: 3100, Loss: 0.7806501984596252


Epoch 0:  11%|█         | 3151/29571 [42:34<6:07:49,  1.20batch/s, loss=1]

Epoch: 0, Iteration: 3150, Loss: 0.7751734256744385


Epoch 0:  11%|█         | 3201/29571 [43:14<5:39:03,  1.30batch/s, loss=0.998]

Epoch: 0, Iteration: 3200, Loss: 0.9292445778846741


Epoch 0:  11%|█         | 3251/29571 [43:55<6:17:13,  1.16batch/s, loss=0.996]

Epoch: 0, Iteration: 3250, Loss: 0.6140727996826172


Epoch 0:  11%|█         | 3301/29571 [44:35<6:20:47,  1.15batch/s, loss=0.996]

Epoch: 0, Iteration: 3300, Loss: 0.8150427341461182


Epoch 0:  11%|█▏        | 3351/29571 [45:16<5:43:20,  1.27batch/s, loss=0.994]

Epoch: 0, Iteration: 3350, Loss: 1.0017802715301514


Epoch 0:  12%|█▏        | 3401/29571 [45:55<5:50:41,  1.24batch/s, loss=0.991]

Epoch: 0, Iteration: 3400, Loss: 0.7302177548408508


Epoch 0:  12%|█▏        | 3451/29571 [46:34<4:54:33,  1.48batch/s, loss=0.988]

Epoch: 0, Iteration: 3450, Loss: 0.6424002051353455


Epoch 0:  12%|█▏        | 3501/29571 [47:14<6:03:51,  1.19batch/s, loss=0.985]

Epoch: 0, Iteration: 3500, Loss: 0.7271562814712524


Epoch 0:  12%|█▏        | 3551/29571 [47:56<5:59:56,  1.20batch/s, loss=0.982]

Epoch: 0, Iteration: 3550, Loss: 0.8018633723258972


Epoch 0:  12%|█▏        | 3601/29571 [48:36<5:11:00,  1.39batch/s, loss=0.981]

Epoch: 0, Iteration: 3600, Loss: 0.7707138061523438


Epoch 0:  12%|█▏        | 3651/29571 [49:15<5:33:09,  1.30batch/s, loss=0.983]

Epoch: 0, Iteration: 3650, Loss: 1.5262399911880493


Epoch 0:  13%|█▎        | 3701/29571 [49:56<5:47:30,  1.24batch/s, loss=0.99] 

Epoch: 0, Iteration: 3700, Loss: 1.3091509342193604


Epoch 0:  13%|█▎        | 3751/29571 [50:35<5:09:59,  1.39batch/s, loss=0.997]

Epoch: 0, Iteration: 3750, Loss: 1.5166168212890625


Epoch 0:  13%|█▎        | 3801/29571 [51:15<5:41:36,  1.26batch/s, loss=1]    

Epoch: 0, Iteration: 3800, Loss: 1.1218323707580566


Epoch 0:  13%|█▎        | 3851/29571 [51:56<6:08:45,  1.16batch/s, loss=1]

Epoch: 0, Iteration: 3850, Loss: 0.8289463520050049


Epoch 0:  13%|█▎        | 3901/29571 [52:37<5:35:50,  1.27batch/s, loss=1]

Epoch: 0, Iteration: 3900, Loss: 0.9247900247573853


Epoch 0:  13%|█▎        | 3951/29571 [53:17<6:18:39,  1.13batch/s, loss=1]

Epoch: 0, Iteration: 3950, Loss: 1.0130199193954468


Epoch 0:  14%|█▎        | 4001/29571 [54:00<6:11:55,  1.15batch/s, loss=1]

Epoch: 0, Iteration: 4000, Loss: 0.9894225001335144


Epoch 0:  14%|█▎        | 4051/29571 [54:40<6:01:38,  1.18batch/s, loss=1]

Epoch: 0, Iteration: 4050, Loss: 1.079021692276001


Epoch 0:  14%|█▍        | 4101/29571 [55:21<6:10:00,  1.15batch/s, loss=1]

Epoch: 0, Iteration: 4100, Loss: 0.954132616519928


Epoch 0:  14%|█▍        | 4151/29571 [56:00<5:17:38,  1.33batch/s, loss=1]

Epoch: 0, Iteration: 4150, Loss: 0.7721267342567444


Epoch 0:  14%|█▍        | 4201/29571 [56:41<5:10:49,  1.36batch/s, loss=1]

Epoch: 0, Iteration: 4200, Loss: 0.8453388214111328


Epoch 0:  14%|█▍        | 4251/29571 [57:20<6:10:03,  1.14batch/s, loss=1]

Epoch: 0, Iteration: 4250, Loss: 0.8521614670753479


Epoch 0:  15%|█▍        | 4301/29571 [58:00<5:14:10,  1.34batch/s, loss=0.999]

Epoch: 0, Iteration: 4300, Loss: 1.0069053173065186


Epoch 0:  15%|█▍        | 4351/29571 [58:41<5:45:22,  1.22batch/s, loss=0.997]

Epoch: 0, Iteration: 4350, Loss: 0.8003000020980835


Epoch 0:  15%|█▍        | 4401/29571 [59:20<5:00:02,  1.40batch/s, loss=0.996]

Epoch: 0, Iteration: 4400, Loss: 1.0505023002624512


Epoch 0:  15%|█▌        | 4451/29571 [1:00:01<5:55:51,  1.18batch/s, loss=0.995]

Epoch: 0, Iteration: 4450, Loss: 0.6809003949165344


Epoch 0:  15%|█▌        | 4501/29571 [1:00:41<5:18:44,  1.31batch/s, loss=0.994]

Epoch: 0, Iteration: 4500, Loss: 0.8076515793800354


Epoch 0:  15%|█▌        | 4551/29571 [1:01:21<5:44:50,  1.21batch/s, loss=0.993]

Epoch: 0, Iteration: 4550, Loss: 1.311490535736084


Epoch 0:  16%|█▌        | 4601/29571 [1:02:02<5:32:56,  1.25batch/s, loss=0.993]

Epoch: 0, Iteration: 4600, Loss: 0.8919007778167725


Epoch 0:  16%|█▌        | 4651/29571 [1:02:41<5:44:24,  1.21batch/s, loss=0.992]

Epoch: 0, Iteration: 4650, Loss: 0.9578607678413391


Epoch 0:  16%|█▌        | 4701/29571 [1:03:21<5:44:25,  1.20batch/s, loss=0.991]

Epoch: 0, Iteration: 4700, Loss: 0.6026515960693359


Epoch 0:  16%|█▌        | 4751/29571 [1:04:02<5:33:15,  1.24batch/s, loss=0.992]

Epoch: 0, Iteration: 4750, Loss: 0.9154365062713623


Epoch 0:  16%|█▌        | 4801/29571 [1:04:42<5:45:09,  1.20batch/s, loss=0.992]

Epoch: 0, Iteration: 4800, Loss: 1.0544462203979492


Epoch 0:  16%|█▋        | 4851/29571 [1:05:23<5:51:07,  1.17batch/s, loss=0.993]

Epoch: 0, Iteration: 4850, Loss: 1.155287742614746


Epoch 0:  17%|█▋        | 4901/29571 [1:06:03<5:03:03,  1.36batch/s, loss=0.994]

Epoch: 0, Iteration: 4900, Loss: 0.905718982219696


Epoch 0:  17%|█▋        | 4951/29571 [1:06:44<5:24:59,  1.26batch/s, loss=0.993]

Epoch: 0, Iteration: 4950, Loss: 0.8287535309791565


Epoch 0:  17%|█▋        | 5001/29571 [1:07:24<4:57:54,  1.37batch/s, loss=0.993]

Epoch: 0, Iteration: 5000, Loss: 0.9965054988861084


Epoch 0:  17%|█▋        | 5051/29571 [1:08:05<5:55:26,  1.15batch/s, loss=0.993]

Epoch: 0, Iteration: 5050, Loss: 1.1698986291885376


Epoch 0:  17%|█▋        | 5101/29571 [1:08:45<4:53:13,  1.39batch/s, loss=0.993]

Epoch: 0, Iteration: 5100, Loss: 1.1545238494873047


Epoch 0:  17%|█▋        | 5151/29571 [1:09:26<5:14:25,  1.29batch/s, loss=0.993]

Epoch: 0, Iteration: 5150, Loss: 1.0185304880142212


Epoch 0:  18%|█▊        | 5201/29571 [1:10:07<5:01:23,  1.35batch/s, loss=0.993]

Epoch: 0, Iteration: 5200, Loss: 0.905238687992096


Epoch 0:  18%|█▊        | 5251/29571 [1:10:47<5:46:41,  1.17batch/s, loss=0.992]

Epoch: 0, Iteration: 5250, Loss: 0.9193329215049744


Epoch 0:  18%|█▊        | 5301/29571 [1:11:26<5:22:10,  1.26batch/s, loss=0.991]

Epoch: 0, Iteration: 5300, Loss: 1.0775712728500366


Epoch 0:  18%|█▊        | 5351/29571 [1:12:06<5:43:49,  1.17batch/s, loss=0.99] 

Epoch: 0, Iteration: 5350, Loss: 0.7223323583602905


Epoch 0:  18%|█▊        | 5401/29571 [1:12:48<5:38:51,  1.19batch/s, loss=0.989]

Epoch: 0, Iteration: 5400, Loss: 0.900412380695343


Epoch 0:  18%|█▊        | 5451/29571 [1:13:28<5:36:17,  1.20batch/s, loss=0.989]

Epoch: 0, Iteration: 5450, Loss: 1.3376609086990356


Epoch 0:  19%|█▊        | 5501/29571 [1:14:08<5:11:05,  1.29batch/s, loss=0.989]

Epoch: 0, Iteration: 5500, Loss: 0.6887228488922119


Epoch 0:  19%|█▉        | 5551/29571 [1:14:48<6:00:33,  1.11batch/s, loss=0.989]

Epoch: 0, Iteration: 5550, Loss: 1.0272034406661987


Epoch 0:  19%|█▉        | 5601/29571 [1:15:30<5:54:14,  1.13batch/s, loss=0.989]

Epoch: 0, Iteration: 5600, Loss: 0.8808602094650269


Epoch 0:  19%|█▉        | 5651/29571 [1:16:10<5:55:55,  1.12batch/s, loss=0.988]

Epoch: 0, Iteration: 5650, Loss: 0.9387532472610474


Epoch 0:  19%|█▉        | 5701/29571 [1:16:48<5:07:21,  1.29batch/s, loss=0.988]

Epoch: 0, Iteration: 5700, Loss: 0.7103009223937988


Epoch 0:  19%|█▉        | 5751/29571 [1:17:29<5:56:49,  1.11batch/s, loss=0.987]

Epoch: 0, Iteration: 5750, Loss: 0.9138044118881226


Epoch 0:  20%|█▉        | 5801/29571 [1:18:09<5:29:01,  1.20batch/s, loss=0.986]

Epoch: 0, Iteration: 5800, Loss: 0.6680952310562134


Epoch 0:  20%|█▉        | 5851/29571 [1:18:49<5:40:02,  1.16batch/s, loss=0.985]

Epoch: 0, Iteration: 5850, Loss: 0.8722472190856934


Epoch 0:  20%|█▉        | 5901/29571 [1:19:30<4:52:55,  1.35batch/s, loss=0.984]

Epoch: 0, Iteration: 5900, Loss: 0.982546329498291


Epoch 0:  20%|██        | 5951/29571 [1:20:10<5:03:59,  1.29batch/s, loss=0.989]

Epoch: 0, Iteration: 5950, Loss: 1.3618931770324707


Epoch 0:  20%|██        | 6001/29571 [1:20:49<4:54:23,  1.33batch/s, loss=0.993]

Epoch: 0, Iteration: 6000, Loss: 1.3273968696594238


Epoch 0:  20%|██        | 6051/29571 [1:21:28<4:58:10,  1.31batch/s, loss=0.996]

Epoch: 0, Iteration: 6050, Loss: 1.4900819063186646


Epoch 0:  21%|██        | 6101/29571 [1:22:09<5:05:13,  1.28batch/s, loss=0.998]

Epoch: 0, Iteration: 6100, Loss: 1.3265788555145264


Epoch 0:  21%|██        | 6151/29571 [1:22:49<5:38:01,  1.15batch/s, loss=1]    

Epoch: 0, Iteration: 6150, Loss: 1.9955787658691406


Epoch 0:  21%|██        | 6201/29571 [1:23:30<5:51:31,  1.11batch/s, loss=1]

Epoch: 0, Iteration: 6200, Loss: 0.9365532994270325


Epoch 0:  21%|██        | 6251/29571 [1:24:11<5:26:17,  1.19batch/s, loss=1.01]

Epoch: 0, Iteration: 6250, Loss: 1.1350319385528564


Epoch 0:  21%|██▏       | 6301/29571 [1:24:50<5:32:04,  1.17batch/s, loss=1.01]

Epoch: 0, Iteration: 6300, Loss: 1.1146891117095947


Epoch 0:  21%|██▏       | 6351/29571 [1:25:29<4:49:10,  1.34batch/s, loss=1.01]

Epoch: 0, Iteration: 6350, Loss: 1.219698190689087


Epoch 0:  22%|██▏       | 6401/29571 [1:26:12<5:36:18,  1.15batch/s, loss=1.01]

Epoch: 0, Iteration: 6400, Loss: 1.0182439088821411


Epoch 0:  22%|██▏       | 6451/29571 [1:26:53<5:43:00,  1.12batch/s, loss=1.01]

Epoch: 0, Iteration: 6450, Loss: 1.2695486545562744


Epoch 0:  22%|██▏       | 6501/29571 [1:27:34<5:20:04,  1.20batch/s, loss=1.01]

Epoch: 0, Iteration: 6500, Loss: 1.0675909519195557


Epoch 0:  22%|██▏       | 6551/29571 [1:28:14<5:10:33,  1.24batch/s, loss=1.01]

Epoch: 0, Iteration: 6550, Loss: 0.962177038192749


Epoch 0:  22%|██▏       | 6601/29571 [1:28:53<4:28:12,  1.43batch/s, loss=1.01]

Epoch: 0, Iteration: 6600, Loss: 1.0762782096862793


Epoch 0:  22%|██▏       | 6651/29571 [1:29:35<5:43:54,  1.11batch/s, loss=1.01]

Epoch: 0, Iteration: 6650, Loss: 1.2094782590866089


Epoch 0:  23%|██▎       | 6701/29571 [1:30:17<5:36:21,  1.13batch/s, loss=1.01]

Epoch: 0, Iteration: 6700, Loss: 0.9089945554733276


Epoch 0:  23%|██▎       | 6751/29571 [1:30:57<5:22:01,  1.18batch/s, loss=1.01]

Epoch: 0, Iteration: 6750, Loss: 1.0517241954803467


Epoch 0:  23%|██▎       | 6801/29571 [1:31:38<5:16:08,  1.20batch/s, loss=1.01]

Epoch: 0, Iteration: 6800, Loss: 0.9959408044815063


Epoch 0:  23%|██▎       | 6851/29571 [1:32:18<4:48:18,  1.31batch/s, loss=1.01]

Epoch: 0, Iteration: 6850, Loss: 1.1867375373840332


Epoch 0:  23%|██▎       | 6901/29571 [1:33:00<5:29:39,  1.15batch/s, loss=1.02]

Epoch: 0, Iteration: 6900, Loss: 1.0225059986114502


Epoch 0:  24%|██▎       | 6951/29571 [1:33:41<5:27:03,  1.15batch/s, loss=1.02]

Epoch: 0, Iteration: 6950, Loss: 0.9654570817947388


Epoch 0:  24%|██▎       | 7001/29571 [1:34:23<4:45:32,  1.32batch/s, loss=1.02]

Epoch: 0, Iteration: 7000, Loss: 0.9936019778251648


Epoch 0:  24%|██▍       | 7051/29571 [1:35:04<5:09:44,  1.21batch/s, loss=1.02]

Epoch: 0, Iteration: 7050, Loss: 1.022780418395996


Epoch 0:  24%|██▍       | 7101/29571 [1:35:44<5:06:41,  1.22batch/s, loss=1.02]

Epoch: 0, Iteration: 7100, Loss: 1.3136519193649292


Epoch 0:  24%|██▍       | 7151/29571 [1:36:24<5:04:36,  1.23batch/s, loss=1.02]

Epoch: 0, Iteration: 7150, Loss: 1.0449111461639404


Epoch 0:  24%|██▍       | 7201/29571 [1:37:04<4:39:15,  1.34batch/s, loss=1.02]

Epoch: 0, Iteration: 7200, Loss: 1.201991319656372


Epoch 0:  25%|██▍       | 7251/29571 [1:37:44<4:43:21,  1.31batch/s, loss=1.01]

Epoch: 0, Iteration: 7250, Loss: 1.1133041381835938


Epoch 0:  25%|██▍       | 7301/29571 [1:38:26<4:59:00,  1.24batch/s, loss=1.01]

Epoch: 0, Iteration: 7300, Loss: 0.7504031658172607


Epoch 0:  25%|██▍       | 7351/29571 [1:39:07<4:52:08,  1.27batch/s, loss=1.01]

Epoch: 0, Iteration: 7350, Loss: 0.9096827507019043


Epoch 0:  25%|██▌       | 7401/29571 [1:39:46<5:08:09,  1.20batch/s, loss=1.01]

Epoch: 0, Iteration: 7400, Loss: 0.8868407607078552


Epoch 0:  25%|██▌       | 7451/29571 [1:40:26<5:11:13,  1.18batch/s, loss=1.01]

Epoch: 0, Iteration: 7450, Loss: 0.8563796281814575


Epoch 0:  25%|██▌       | 7501/29571 [1:41:05<4:43:36,  1.30batch/s, loss=1.01]

Epoch: 0, Iteration: 7500, Loss: 1.0367681980133057


Epoch 0:  26%|██▌       | 7551/29571 [1:41:47<4:55:34,  1.24batch/s, loss=1.01]

Epoch: 0, Iteration: 7550, Loss: 0.870927095413208


Epoch 0:  26%|██▌       | 7601/29571 [1:42:28<5:09:12,  1.18batch/s, loss=1.01]

Epoch: 0, Iteration: 7600, Loss: 1.0566192865371704


Epoch 0:  26%|██▌       | 7651/29571 [1:43:09<5:26:07,  1.12batch/s, loss=1.01]

Epoch: 0, Iteration: 7650, Loss: 0.8886857032775879


Epoch 0:  26%|██▌       | 7701/29571 [1:43:48<4:58:54,  1.22batch/s, loss=1.01]

Epoch: 0, Iteration: 7700, Loss: 0.7419942617416382


Epoch 0:  26%|██▌       | 7742/29571 [1:44:41<4:55:12,  1.23batch/s, loss=1.01]


KeyboardInterrupt: 