In [1]:
import torch
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.optim as optim
import matplotlib.pyplot as plt

In [2]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)


train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)


Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26.4M/26.4M [00:00<00:00, 112MB/s]


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29.5k/29.5k [00:00<00:00, 5.44MB/s]


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4.42M/4.42M [00:00<00:00, 57.1MB/s]


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5.15k/5.15k [00:00<00:00, 23.7MB/s]


Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw



In [3]:

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)


in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, 10)


device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)


Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
100%|██████████| 160M/160M [00:01<00:00, 167MB/s]


FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(

In [4]:

params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)


In [None]:
from tqdm import tqdm

num_epochs = 1

def train_one_epoch(model, optimizer, data_loader, device):
    model.train()


    pbar = tqdm(data_loader, total=len(data_loader), desc="Training Progress", dynamic_ncols=True)

    for images, labels in pbar:

        boxes = []
        for image in images:
            box = torch.tensor([[0, 0, 224, 224]], dtype=torch.float32)
            boxes.append(box)
        boxes = torch.stack(boxes)

        images = images.to(device)
        boxes = boxes.to(device)
        labels = labels.to(device)


        target = [{'boxes': boxes[i], 'labels': labels[i].unsqueeze(0)} for i in range(len(labels))]


        optimizer.zero_grad()

        loss_dict = model(images, target)

        losses = sum(loss for loss in loss_dict.values())

        losses.backward()
        optimizer.step()

        pbar.set_postfix(loss=losses.item())

    return losses.item()


for epoch in range(num_epochs):
    loss = train_one_epoch(model, optimizer, train_loader, device)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss}")


Training Progress:  61%|██████▏   | 4608/7500 [1:57:47<1:14:22,  1.54s/it, loss=0.0089]

In [None]:

def evaluate(model, data_loader, device):
    model.eval()
    results = []
    with torch.no_grad():
        for images, labels in data_loader:

            boxes = []
            for image in images:
                box = torch.tensor([[0, 0, 224, 224]], dtype=torch.float32)
                boxes.append(box)
            boxes = torch.stack(boxes)


            images = images.to(device)
            boxes = boxes.to(device)
            labels = labels.to(device)

            target = [{'boxes': boxes[i], 'labels': labels[i]} for i in range(len(labels))]

            prediction = model(images)
            results.append(prediction)
    return results

evaluate(model, test_loader, device)


In [None]:
def visualize_predictions(model, data_loader, device):
    model.eval()
    with torch.no_grad():
        for images, labels in data_loader:
            images = images.to(device)
            predictions = model(images)

            image = images[0].cpu().numpy().transpose((1, 2, 0))
            plt.imshow(image)
            plt.axis('off')

            boxes = predictions[0]['boxes'].cpu().numpy()
            labels = predictions[0]['labels'].cpu().numpy()
            for box, label in zip(boxes, labels):
                x_min, y_min, x_max, y_max = box
                plt.gca().add_patch(plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min, linewidth=2, edgecolor='r', facecolor='none'))
                plt.text(x_min, y_min, f'{label}', color='r', fontsize=12)
            plt.show()
            break

visualize_predictions(model, test_loader, device)
