In [1]:
import torch, wandb
from PIL import Image
from os import listdir
from torch import nn, optim
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torchvision import transforms, models
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import LearningRateMonitor
from torch.utils.data import random_split, TensorDataset, DataLoader

In [2]:
# CONSTANTS

# load images and transform
TRAIN_PATH = './training_images'
CLASSES = 200
# data preprocession
RESIZE_SIZE = 256
INPUT_SIZE = 224
# training/validation split constants
DATA_SIZE = 3000
DATA_PER_CLASS = int(DATA_SIZE / CLASSES)
TRAIN_RATIO = 0.75
TRAIN_PER_CLASS = int(DATA_PER_CLASS * TRAIN_RATIO)
VAL_PER_CLASS = DATA_PER_CLASS - TRAIN_PER_CLASS
# training hyperparameters
BATCH_SIZE = 32
EPOCHS = 30

In [3]:
# class to id (0, 1, ..., 199) mapping
class_to_id = {}
with open('classes.txt', 'r') as f:
    for label in f.readlines():
        class_id, name = label.split('.')
        class_to_id[label[: -1]] = int(class_id) - 1

# filename to id mapping
file_to_id = {}
with open('training_labels.txt', 'r') as f:
    for line in f.readlines():
        filename, label = line.split()
        file_to_id[filename] = class_to_id[label]

In [4]:
# image transformations
crop = transforms.Compose([
    transforms.Resize((RESIZE_SIZE, RESIZE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.CenterCrop(INPUT_SIZE),
])

flip = transforms.RandomHorizontalFlip(p=1)

In [5]:
# create 2d list of data
dataset = [[] for _ in range(CLASSES)]

# load image and map to id
for img_file in listdir(TRAIN_PATH):
    img_id = file_to_id[img_file]
    img = Image.open(f'{TRAIN_PATH}/{img_file}')
    dataset[img_id].append(crop(img))

In [6]:
#initialize data
train_data, train_label, val_data, val_label = [], [], [], []

# for each group ID
for i in range(CLASSES):
    perm = torch.randperm(DATA_PER_CLASS)
    # add 11 original images
    train_data.extend([dataset[i][idx]
                       for idx in perm[: TRAIN_PER_CLASS]])
    # add 11 flip images
    train_data.extend([flip(dataset[i][idx])
                       for idx in perm[: TRAIN_PER_CLASS]])
    # 4 images go to validation
    val_data.extend([dataset[i][idx]
                     for idx in perm[TRAIN_PER_CLASS: ]])
    # add labels
    train_label.extend([i] * TRAIN_PER_CLASS * 2)
    val_label.extend([i] * VAL_PER_CLASS)

In [7]:
# list to dataset
train_data = torch.stack(train_data)
train_label = torch.LongTensor(train_label)
val_data = torch.stack(val_data)
val_label = torch.LongTensor(val_label)
train_dataset = TensorDataset(train_data, train_label)
val_dataset = TensorDataset(val_data, val_label)

In [8]:
# save datasets for further usage
torch.save(train_dataset, 'train_dataset.pt')
torch.save(val_dataset, 'val_dataset.pt')

In [9]:
# # load while using it later on
# train_dataset = torch.load('train_dataset.pt')
# val_dataset = torch.load('val_dataset.pt')

In [10]:
# define model class
class CNN(pl.LightningModule):
    def __init__(self, batch_size):
        super().__init__()
        self.pretrained = models.resnet50(pretrained=True).cuda()
        setattr(self.pretrained, 'fc', nn.Linear(2048, 200).cuda())
        self.criterion = nn.CrossEntropyLoss()
        self.hparams.batch_size = batch_size
    
    def forward(self, x):
        return self.pretrained(x)

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(),
                              lr=0.01, momentum=0.9, weight_decay=1e-4)
        return [optimizer], []
    
    def training_step (self, train_batch, batch_idx):
        x, y = train_batch
        pred = self.forward(x)
        loss = self.criterion(pred, y)
        y_pred = torch.max(pred.data, 1).indices == y
        acc = y_pred.sum() / y.shape[0]
        self.log('Train Loss', loss, on_step=True,
                 on_epoch=True, prog_bar=True, logger=True)
        self.log('Train Acc', acc, on_step=True,
                 on_epoch=True, prog_bar=True, logger=True)
        return loss
    
    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        pred = self.forward(x)
        loss = self.criterion(pred, y)
        y_pred = torch.max(pred.data, 1).indices == y
        acc = y_pred.sum() / y.shape[0]
        self.log('Val Loss', loss, on_step=True,
                 on_epoch=True, prog_bar=True, logger=True)
        self.log('Val Acc', acc, on_step=True,
                 on_epoch=True, prog_bar=True, logger=True)
        return loss
    
    def validation_epoch_end(self, outputs):
        torch.save(self.state_dict(), 'resnet50.pt')
    
    def train_dataloader(self):
        return DataLoader(train_dataset, shuffle=True,
                          batch_size=self.hparams.batch_size, num_workers=4)
    
    def val_dataloader(self):
        return DataLoader(val_dataset, shuffle=True,
                          batch_size=self.hparams.batch_size, num_workers=4)

In [11]:
# define model
model = CNN(BATCH_SIZE)
print(model)

CNN(
  (pretrained): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (

In [12]:
#set logger, monitor
wandb_logger = WandbLogger(project='vrdl_hw1', entity='udchen')
lr_monitor = LearningRateMonitor(logging_interval='step')

# define trainer
trainer = Trainer(
    gpus=1,
    logger=wandb_logger,
    track_grad_norm=2,
    callbacks=[lr_monitor],
    log_every_n_steps=10,
    max_epochs=EPOCHS,
    stochastic_weight_avg=True,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [13]:
# fit model
trainer.fit(model, model.train_dataloader(), model.val_dataloader())

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mudchen[0m (use `wandb login --relogin` to force relogin)



  | Name       | Type             | Params
------------------------------------------------
0 | pretrained | ResNet           | 23.9 M
1 | criterion  | CrossEntropyLoss | 0     
------------------------------------------------
23.9 M    Trainable params
0         Non-trainable params
23.9 M    Total params
95.671    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f"Your {mode}_dataloader has `shuffle=True`, it is best practice to turn"


Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [14]:
torch.save(model.state_dict(), 'resnet50.pt')