In [7]:
import torch
from torch import nn
from PIL import Image
from os import listdir
import pytorch_lightning as pl
from torchvision import transforms, models

In [None]:
# CONSTANTS

# --- data location / label space ---
TEST_PATH = './testing_images'  # directory containing the test images
CLASSES = 200                   # number of target classes

# --- data preprocessing ---
# images are resized to RESIZE_SIZE x RESIZE_SIZE, then center-cropped to INPUT_SIZE x INPUT_SIZE
RESIZE_SIZE, INPUT_SIZE = 256, 224

# --- training hyperparameters ---
BATCH_SIZE = 64

In [8]:
# Map zero-based class index (0, 1, ..., 199) -> class label line from classes.txt.
# Each line is expected to look like "<1-based id>.<class name>".
# NOTE: the stored value is the raw line, trailing newline included; code that
# writes it out should strip it.
id_to_class = {}
with open('classes.txt', 'r') as f:
    for label in f:
        if not label.strip():
            continue  # tolerate blank lines (e.g. an extra newline at end of file)
        # maxsplit=1 so a class name containing '.' does not break the unpacking
        class_id, _ = label.split('.', 1)
        id_to_class[int(class_id) - 1] = label

In [9]:
# define model class
class CNN(pl.LightningModule):
    """ResNet-50 classifier with a replaced final fully connected layer.

    Args:
        batch_size: batch size used by the train/val dataloaders; stored as
            ``self.hparams.batch_size``.
        num_classes: number of outputs of the new classification head.
            Defaults to 200, the value previously hard-coded.
    """

    def __init__(self, batch_size, num_classes=200):
        super().__init__()
        # torchvision ResNet-50 with pretrained weights; its `fc` head is replaced below.
        self.pretrained = models.resnet50(pretrained=True)
        # Read the head's input width from the backbone instead of hard-coding 2048.
        self.pretrained.fc = nn.Linear(self.pretrained.fc.in_features, num_classes)
        self.criterion = nn.CrossEntropyLoss()
        self.hparams.batch_size = batch_size

    def forward(self, x):
        """Return the raw (un-normalised) class logits for the image batch ``x``."""
        return self.pretrained(x)

    def configure_optimizers(self):
        """SGD with momentum and weight decay; no learning-rate scheduler."""
        # `optim` is never imported in this notebook, so reach it through `torch`.
        optimizer = torch.optim.SGD(self.parameters(),
                                    lr=0.01, momentum=0.9, weight_decay=1e-4)
        return [optimizer], []

    def _shared_step(self, batch, stage):
        """Compute, log and return the loss for one ``(x, y)`` batch.

        ``stage`` prefixes the logged metric names ('Train' or 'Val').
        """
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        # Fraction of samples whose arg-max logit matches the target. `.float()`
        # avoids integer division of a LongTensor on older torch versions.
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log(f'{stage} Loss', loss, on_step=True,
                 on_epoch=True, prog_bar=True, logger=True)
        self.log(f'{stage} Acc', acc, on_step=True,
                 on_epoch=True, prog_bar=True, logger=True)
        return loss

    def training_step(self, train_batch, batch_idx):
        """Lightning training hook: log 'Train Loss'/'Train Acc', return the loss."""
        return self._shared_step(train_batch, 'Train')

    def validation_step(self, val_batch, batch_idx):
        """Lightning validation hook: log 'Val Loss'/'Val Acc', return the loss."""
        return self._shared_step(val_batch, 'Val')

    def validation_epoch_end(self, outputs):
        # Unconditionally overwrites resnet50.pt after every validation epoch, so the
        # file holds the most recent weights, not necessarily the best ones.
        # NOTE(review): this hook was removed in PyTorch Lightning 2.0
        # (on_validation_epoch_end replaces it) - confirm the installed version.
        torch.save(self.state_dict(), 'resnet50.pt')

    def train_dataloader(self):
        # NOTE(review): `train_dataset` is expected as a notebook global; it is not
        # defined anywhere in this notebook, so training requires defining it first.
        return torch.utils.data.DataLoader(train_dataset, shuffle=True,
                                           batch_size=self.hparams.batch_size,
                                           num_workers=4)

    def val_dataloader(self):
        # NOTE(review): `val_dataset` is expected as a notebook global (not defined here).
        # Shuffling does not change validation metrics, so keep a fixed order.
        return torch.utils.data.DataLoader(val_dataset, shuffle=False,
                                           batch_size=self.hparams.batch_size,
                                           num_workers=4)

In [10]:
# define model
# Rebuild the network and restore the fine-tuned weights written during training.
# NOTE(review): the printed repr saved below shows `pretrained` as an nn.Sequential,
# which does not match the CNN class defined above; re-run to confirm the checkpoint
# keys actually match this class.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CNN(BATCH_SIZE)
# map_location lets a checkpoint saved on GPU be restored on whichever device is available.
model.load_state_dict(torch.load('resnet50.pt', map_location=device))
# Move parameters to the inference device (load_state_dict alone leaves them on CPU)
# and switch to eval mode.
model = model.to(device).eval()
print(model)

CNN(
  (pretrained): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(64, 2

In [11]:
# image transformations
# Deterministic test-time pipeline: fixed-size resize, central crop, then tensor
# conversion and per-channel normalisation with the ImageNet mean/std used by
# torchvision's pretrained models. Cropping before ToTensor/Normalize selects the
# same pixels as cropping afterwards, because Normalize acts per pixel and channel.
crop = transforms.Compose(
    [
        transforms.Resize((RESIZE_SIZE, RESIZE_SIZE)),
        transforms.CenterCrop(INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
)

# Always-on horizontal flip (p=1). Not used by the cells shown here -
# presumably kept for flip test-time augmentation.
flip = transforms.RandomHorizontalFlip(p=1)

In [12]:
# test answer generation
# For every test image listed in sample_answer.txt, predict a class and write
# "<filename> <class label>" lines to answer.txt.
# Send inputs to wherever the model's parameters live, so this works on CPU or GPU
# (the model is not guaranteed to have been moved to CUDA).
device = next(model.parameters()).device
with open('sample_answer.txt', 'r') as f1, open('answer.txt', 'w') as f2:
    for line in f1:
        fields = line.split()
        if not fields:
            continue  # skip blank lines
        # Only the filename column is used; the remaining column(s) are ignored.
        filename = fields[0]
        # convert('RGB') so grayscale/CMYK images still give the 3 channels that
        # Normalize and the model expect; the context manager closes the file.
        with Image.open(f'{TEST_PATH}/{filename}') as img:
            images = crop(img.convert('RGB')).unsqueeze(0).to(device)
        with torch.no_grad():
            pred = model(images).sum(dim=0)
        # The stored class label may carry a trailing newline (or lack one on the last
        # line of classes.txt); strip it and terminate each answer line explicitly.
        class_label = id_to_class[torch.argmax(pred).item()].strip()
        f2.write(f'{filename} {class_label}\n')